Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Use ChainRules #189

Merged
merged 71 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
32f951a
Add ChainRules to Project.toml
oxinabox Oct 21, 2020
83d24c5
Load all rules from ChainRules
oxinabox Sep 3, 2020
563b0a6
Make Branch parametric on the pullback type
oxinabox Oct 14, 2020
5db5bd3
Don't by pass the tape abstraction in propagate
oxinabox Oct 14, 2020
c1d383a
use update! to handle InplaceableThunks
oxinabox Sep 7, 2020
23e9a64
unthunk public API
oxinabox Sep 17, 2020
69edf48
Remove all the array rules that are unneeded.
oxinabox Sep 17, 2020
fb45de7
For now keep diffrules for special nosensitivity scalar case so as no…
oxinabox Sep 18, 2020
5369e2d
Correct tests re #191
oxinabox Sep 18, 2020
7b30b1c
fix scalar-array \ and / tests
oxinabox Sep 30, 2020
29ff45d
Delete generic linear algebra that has moved to ChainRules
oxinabox Sep 30, 2020
04b4582
Delete Diagonal methods that moved to ChainRules
oxinabox Sep 30, 2020
0709260
Delete BLAS rules that moved to ChainRules
oxinabox Sep 30, 2020
08d9a2e
Use ChainRules for Cholesky
oxinabox Oct 1, 2020
2aeef68
move SVD to ChainRules
oxinabox Oct 1, 2020
b71b6d0
Correct ∇ for Symmetric constructor
oxinabox Oct 12, 2020
835f66e
Move strided to ChainRules.jl
oxinabox Oct 12, 2020
bd28593
Delete moved rule for diagonal
oxinabox Oct 12, 2020
1b0df5b
remove structured constructors that moved to ChainRules
oxinabox Oct 13, 2020
78c4c41
delete triangular rules that moved to ChainRules
oxinabox Oct 13, 2020
7aad58a
put stuff in testsets
oxinabox Oct 13, 2020
ae99591
move list of linalg optimizations to the tests
oxinabox Oct 14, 2020
a6ede28
delete never used uniform scaling file
oxinabox Oct 14, 2020
e05bbba
Move indexing over to ChainRules
oxinabox Oct 16, 2020
1549b75
Remove DiffRules entirely
oxinabox Oct 19, 2020
c0a7d81
remove testing scratch file
oxinabox Oct 21, 2020
a16a64b
stop tracking ExprTools as a submodule
oxinabox Oct 21, 2020
fd3c497
make reduce tests use the list of UNITARY sensitivities
oxinabox Oct 21, 2020
c6b4d36
sortout Project.toml
oxinabox Oct 21, 2020
f2c309c
remove import of removed partition functions from tests
oxinabox Oct 21, 2020
773908f
make tests not overwrite functions
oxinabox Oct 21, 2020
2d84b3e
drop support for Special Functions 0.9
oxinabox Oct 21, 2020
e755435
Add docstrings and seperate original_sig from unonized_sig
oxinabox Oct 21, 2020
0fce85b
Apply suggestions from code review
oxinabox Oct 22, 2020
6c1c4df
Apply suggestions from code review
oxinabox Oct 22, 2020
bdd3e43
move deciding what rules to use into its own function and docstring
oxinabox Oct 22, 2020
b17190c
Specific error
oxinabox Oct 23, 2020
ba2670c
Apply suggestions from code review
oxinabox Oct 23, 2020
8e897b9
Support Special Function 0.9
oxinabox Nov 2, 2020
ff44f9c
Apply suggestions from code review
oxinabox Nov 2, 2020
20db255
Remove directly applied (commented out) broadcasting tests
oxinabox Nov 2, 2020
67455fe
Allow SpecialFunction 0.8 for julia 1.0
oxinabox Nov 13, 2020
5151ad8
Delete redundant rule for identity
oxinabox Nov 13, 2020
04f9b58
remove mistakenly added SnoopCompile dependency
oxinabox Nov 13, 2020
859e0a2
Fix comments
oxinabox Dec 4, 2020
fc7fb5a
Don't generate rules for a bunch of nondifferentiable things that cau…
oxinabox Nov 13, 2020
c27a8fd
correct _truly_rename_unionall spelling
oxinabox Dec 4, 2020
6d0559b
handle remove varargs with redundant N and other typevars
oxinabox Dec 4, 2020
bd4271f
remove iunneded variable
oxinabox Dec 4, 2020
251ab3b
filter out nonfields inplace
oxinabox Dec 4, 2020
c470871
split up BINARY_SCALAR_SENSITIVITIES
oxinabox Dec 8, 2020
9d7e87b
Remove last of the varient ȳ
oxinabox Dec 8, 2020
f11e7eb
fix comment typo
oxinabox Mar 1, 2021
a5f758a
typos in comments/docs
oxinabox Mar 1, 2021
ff8cf0d
Update for new ChainRulesCore
oxinabox Jun 24, 2021
24f52bd
Make tests of identity not try and use the version of ∇ that requires…
oxinabox Jun 24, 2021
f0c05b4
Stop testing equality when approximate equality is better
oxinabox Jun 24, 2021
4f5241c
remove matrix exp which is now in ChainRules
oxinabox Jun 24, 2021
64502d1
Change test to reflect that asymmetric matrix expodential now works
oxinabox Jun 25, 2021
4afefc0
update now the SVD.Vt now works
oxinabox Jun 25, 2021
1da2489
Link to chainrules hooks outside of __init__ for MUCH faster loading
oxinabox Jun 25, 2021
ce65be8
broadcast_axes is now axes and we want to ignore ChainRules def for it
oxinabox Jul 2, 2021
55877bc
Move helpers to ExprTools 1.4
oxinabox Jul 2, 2021
5134063
use collect not float.
oxinabox Jul 2, 2021
09fe3e4
delete strided and its tests which are wrong and it is tested in Chai…
oxinabox Jul 2, 2021
f2fe4cb
tag as a breaking change
oxinabox Jul 2, 2021
c2b9edc
delete old comment
oxinabox Jul 2, 2021
d6328b4
Block using ChainRules for rules remaining in Nabla
oxinabox Jul 2, 2021
d37b487
Update CI to match current prod min julia version
oxinabox Jul 2, 2021
00ef4ac
block one and zero
oxinabox Jul 2, 2021
970893a
Update docs to be compatible with ChainRules and document about using…
oxinabox Jul 2, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ jobs:
- os: windows-latest
arch: x86
include:
# Add a 1.5 job because that's what Invenia actually uses
# Add a 1.6 job because that's what Invenia actually uses
- os: ubuntu-latest
version: 1.5
version: 1.6
arch: x64
steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
version: '1.6'
- run: |
julia --project=docs -e '
using Pkg
Expand Down
10 changes: 7 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
*.pdf
*.DS_Store
*.jl.cov
*.jl.*.cov
*.jl.mem
*.pdf
*.DS_Store
Manifest.toml
docs/build/
docs/build
docs/site
docs/src/assets/chainrules.css
docs/src/assets/indigo.css
.vscode/settings.json
21 changes: 14 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
name = "Nabla"
uuid = "49c96f43-aa6d-5a04-a506-44c7070ebe78"
version = "0.12.3"
version = "0.13.0"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
DiffRules = "0.0, 1"
DualNumbers = "0.6"
FDM = "^0.6"
SpecialFunctions = "0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1"
ChainRules = "0.8"
ChainRulesCore = "0.10.9"
ChainRulesOverloadGeneration = "0.1.2"
ExprTools = "0.1.4"
FDM = "0.6.1"
ForwardDiff = "0.10.12"
SpecialFunctions = "1.5.1"
julia = "^1.3"

[extras]
Expand Down
4 changes: 3 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Nabla = "49c96f43-aa6d-5a04-a506-44c7070ebe78"

[compat]
Documenter = "~0.19"
Documenter = "0.27"
25 changes: 13 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
using Documenter, Nabla
using Documenter
using DocThemeIndigo
using Nabla

const indigo = DocThemeIndigo.install(Nabla)
makedocs(
modules=[Nabla],
format=:html,
format=Documenter.HTML(
prettyurls=false,
assets=[indigo],
),
sitename="Nabla.jl",
authors="Invenia Labs",
pages=[
"Home" => "index.md",
"API" => "pages/api.md",
"Custom Sensitivities" => "pages/custom.md",
"Details" => "pages/autodiff.md",
],
sitename="Nabla.jl",
authors="Invenia Labs",
assets=[
"assets/invenia.css",
],
)


deploydocs(
repo = "github.com/invenia/Nabla.jl.git",
julia = "1.0",
target = "build",
deps = nothing,
make = nothing,
)
push_preview=true,
)
75 changes: 0 additions & 75 deletions docs/src/assets/invenia.css

This file was deleted.

Binary file added docs/src/assets/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 16 additions & 1 deletion docs/src/pages/custom.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
# Custom Sensitivities
# Custom Sensitivities

!!! note "Prefer to use ChainRulesCore to define custom sensitivities"
Nabla supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Nabla.
**It is also much easier, than the Nabla specific way**.
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`.
See the [ChainRules project's documentation for more information](https://www.juliadiff.org/ChainRulesCore.jl/stable/).
**If you are defining your custom sensitivities using ChainRulesCore then you do not need to read this page**, and can consider it as documenting a legacy feature.

This page exists to describe how Nabla works, and how sensitivities can be directly defined for Nabla.
Defining sensitivities this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Nabla works.
It allows for specific definitions of sensitivities that are only defined for Nabla (which might work differently to more generic definitions defined for all AD).

# Legacy Method

Part of the power of Nabla is its extensibility, specifically in the form of defining
custom sensitivities for functions.
Expand Down
19 changes: 14 additions & 5 deletions src/Nabla.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
__precompile__()

module Nabla

using SpecialFunctions
using ChainRules
using ChainRulesCore
using ChainRulesOverloadGeneration
using ExprTools: ExprTools
using ForwardDiff: ForwardDiff
using LinearAlgebra
using Random
using SpecialFunctions
using Statistics

# Some aliases used repeatedly throughout the package.
Expand Down Expand Up @@ -39,10 +44,12 @@ module Nabla
# into a separate module at some point.
include("finite_differencing.jl")

# Sensitivities via ChainRules
include("sensitivities/chainrules.jl")

# Sensitivities for the basics.
include("sensitivities/indexing.jl")
include("sensitivities/scalar.jl")
include("sensitivities/array.jl")

# Sensitivities for functionals.
include("sensitivities/functional/functional.jl")
Expand All @@ -52,14 +59,16 @@ module Nabla
# Linear algebra optimisations.
include("sensitivities/linalg/generic.jl")
include("sensitivities/linalg/symmetric.jl")
include("sensitivities/linalg/strided.jl")
include("sensitivities/linalg/blas.jl")
include("sensitivities/linalg/diagonal.jl")
include("sensitivities/linalg/triangular.jl")
include("sensitivities/linalg/factorization/cholesky.jl")
include("sensitivities/linalg/factorization/svd.jl")

# Checkpointing
include("checkpointing.jl")


# Link up to ChainRulesCore so rules are generated when new rrules are declared.
on_new_rule(generate_overload, rrule)

end # module Nabla
37 changes: 36 additions & 1 deletion src/code_transformation/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ function unionise_type(tp::Union{Symbol, Expr})
return replace_vararg(:(Union{$_tp, Node{<:$tp_clean}}), (_tp, _info))
end

"""
node_type(tp::Union{Symbol, Expr})

Returns an expression for the `Node{<:tp}`. e.g.
`node_type(:Real)` returns `:(Node{<:Real}})`.

Correctly `:(Vararg{Real})` becomes `:(Vararg{Node{<:Real}})`

This is a lot like [`unionise_type`](ref) but it doesn't permit the original type anymore.
"""
function node_type(tp::Union{Symbol, Expr})
(_tp, _info) = remove_vararg(tp)
tp_clean = (isa(_tp, Expr) && _tp.head == Symbol("<:")) ? _tp.args[1] : _tp
return replace_vararg(:(Node{<:$tp_clean}), (_tp, _info))
end


"""
replace_body(unionall::Union{Symbol, Expr}, replacement::Union{Symbol, Expr})

Expand Down Expand Up @@ -91,6 +108,24 @@ function remove_vararg(typ::Expr)
if isa_vararg(typ)
body = get_body(typ)
new_typ = replace_body(typ, body.args[2])

# This is a bit ugly:
# handle interally `where N` from `typ = :(Vararg{FOO, N} where N)` which results in
# `body = :(Vararg{FOO, N})` and `new_type = Foo where N`, we don't need to keep it
# at all, the `where N` wasn't doing anything to begin with, so we just strip it out
if Meta.isexpr(new_typ, :where) && Meta.isexpr(body, :curly, 3)
@assert body.args[1] == :Vararg
T = body.args[2]
N = body.args[3]
if new_typ.args == [T, N] # ($T where $N)
body = :(Vararg{T})
new_typ = T
elseif T == new_typ.args[1] && N ∈ new_typ.args[2:end] # ($T where {?, $N, ?})
body = :(Vararg{T})
filter!(!isequal(N), new_typ.args)
end
end

vararg_info = length(body.args) == 3 ? body.args[3] : :Vararg
return new_typ, vararg_info
else
Expand All @@ -107,7 +142,7 @@ Convert `typ` to the `Vararg` containing elements of type `typ` specified by
replace_vararg(typ::SymOrExpr, vararg_info::Tuple) =
vararg_info[2] == :nothing ?
typ :
vararg_info[2] == :no_N || vararg_info[2] == :Vararg ?
vararg_info[2] == :no_N || vararg_info[2] == :Vararg ? #TODO: :no_N is impossible now?
replace_body(typ, :(Vararg{$(get_body(typ))})) :
replace_body(typ, :(Vararg{$(get_body(typ)), $(vararg_info[2])}))

Expand Down
Loading