Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
226 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <willtebbutt@users.noreply.github.com>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
8d120b2
Case 7 solved
gdalle Jan 27, 2025
cd7167f
Resolve merge conflict
willtebbutt Jan 27, 2025
c5ffae7
Fix precompile issue
willtebbutt Jan 27, 2025
94aa904
Fix isa rule
willtebbutt Jan 27, 2025
cc7a3fa
Fix is_primitive
willtebbutt Jan 27, 2025
70d7183
More test cases
gdalle Feb 4, 2025
aec412e
progress
gdalle Feb 6, 2025
0ea1084
fixes
gdalle Feb 7, 2025
d8a949f
Bump patch vesion
willtebbutt Feb 12, 2025
79844d2
Fix terminators
willtebbutt Feb 12, 2025
49aa4ca
Merge remote-tracking branch 'upstream/wct/fix-terminator-issue' into…
gdalle Feb 12, 2025
9ce99ec
More cases
gdalle Feb 12, 2025
6ce2488
More cases
gdalle Feb 12, 2025
8954361
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
941a2de
Merge remote-tracking branch 'upstream/main' into gd/forward
gdalle Feb 14, 2025
af49eac
Tuple rule
gdalle Feb 14, 2025
0b4e5fa
Merge in main
willtebbutt Mar 14, 2025
8204665
Formatting
willtebbutt Mar 14, 2025
70fec10
Code to view forwards-mode IR from a signature
willtebbutt Mar 14, 2025
6cde147
Use widenconst to get actual argtype from ircode argtypes
willtebbutt Mar 14, 2025
0eabff0
MyInstruction -> new_instruction
willtebbutt Mar 14, 2025
8b391c6
Formatting
willtebbutt Mar 15, 2025
5d6b826
Merge branch 'main' into gd/forward
willtebbutt Mar 17, 2025
a919a28
Various improvements
willtebbutt Mar 18, 2025
2808a12
Rules for foreigncalls
willtebbutt Mar 19, 2025
cb28759
Fix pointer tests with forwards mode
willtebbutt Mar 19, 2025
f9d1697
Enable more tests
willtebbutt Mar 19, 2025
9bc53cc
All derivation tests pass
willtebbutt Mar 19, 2025
d6fc35d
Initial pass over legacy array functionality
willtebbutt Mar 20, 2025
6b2409c
Fix tangent usage in tests
willtebbutt Mar 20, 2025
d6974c1
Rules for nice BLAS functions
willtebbutt Mar 20, 2025
fbcc6ce
Tweak test inputs slightly
willtebbutt Mar 20, 2025
732762b
Enable CI for BLAS and foreigncalls
willtebbutt Mar 20, 2025
fd48f02
Enable linear_algebra rules
willtebbutt Mar 20, 2025
f6bc752
More stuff works
willtebbutt Mar 21, 2025
a96611e
Make IdDict work
willtebbutt Mar 21, 2025
44e78b4
Code to identify SSA uses
willtebbutt Mar 21, 2025
a504413
Fix failing test via special case
willtebbutt Mar 21, 2025
f68c79b
Remove outdated TODO note
willtebbutt Mar 21, 2025
05cbb83
Merge branch 'main' into gd/forward
willtebbutt Mar 21, 2025
30c5294
Fix typo
willtebbutt Mar 21, 2025
fe4ec4a
BLAS support nearly finished
willtebbutt Mar 23, 2025
f771f70
All BLAS rules passing
willtebbutt Mar 24, 2025
86fa1b6
Initial work on getrf
willtebbutt Mar 25, 2025
04ea669
Merge branch 'main' into gd/forward
willtebbutt Mar 25, 2025
e1a1260
getrf frule sketch
willtebbutt Mar 25, 2025
fda2ab9
Merge branch 'gd/forward' of https://github.com/gdalle/Mooncake.jl in…
willtebbutt Mar 25, 2025
37baaf0
Improve getrf performance
willtebbutt Mar 26, 2025
c0c4167
trtrs implementation + type stability checks
willtebbutt Mar 26, 2025
9a12b23
Type stability checks for BLAS rules
willtebbutt Mar 26, 2025
bb8feba
Note Seth's blog
willtebbutt Mar 26, 2025
64d6176
getrs frule implementation
willtebbutt Mar 27, 2025
be57d7f
getri frule implementation
willtebbutt Mar 27, 2025
2934409
potrs
willtebbutt Mar 27, 2025
fe289a0
Enable lapack CI
willtebbutt Mar 27, 2025
39354bc
Fix pivoting
willtebbutt Mar 28, 2025
8bcde33
Enable diff tests integration tests
willtebbutt Mar 28, 2025
497c907
Only run extra CI on 1
willtebbutt Mar 28, 2025
e1dce38
More lapack fixes
willtebbutt Mar 28, 2025
8739c6c
widenconst
willtebbutt Mar 28, 2025
899d4c4
Replace field access with method call
willtebbutt Mar 28, 2025
594ba13
Catch __vec_to_tuple edge case
willtebbutt Mar 28, 2025
0510235
Display more stuff when correctness test fails
willtebbutt Mar 28, 2025
4af2276
Enable more integration tests
willtebbutt Mar 28, 2025
83cd097
Make output on test error sensible
willtebbutt Mar 28, 2025
da3d7ee
Tidy up blas implementations
willtebbutt Mar 28, 2025
eee18dd
Fix pointerset error
willtebbutt Mar 28, 2025
9bd274d
Merge branch 'main' into gd/forward
willtebbutt Mar 28, 2025
3a4f70a
Fix ^ rule
willtebbutt Mar 28, 2025
5aed9b2
Implement from_chain_rule macro
willtebbutt Mar 29, 2025
f4f62c9
Get SpecialFunctions extension working
willtebbutt Mar 29, 2025
9c11e6a
Enable SpecialFunctions in CI
willtebbutt Mar 29, 2025
e19cb63
logexpfunctions
willtebbutt Mar 29, 2025
be93bfd
Run gpu jobs on 1.11 only
willtebbutt Mar 29, 2025
1d1e7e9
Restrict FD step for forward mode
willtebbutt Mar 29, 2025
f21b575
Enable GP tests
willtebbutt Mar 29, 2025
c691a92
More integration testing
willtebbutt Mar 29, 2025
b28961a
bijectors
willtebbutt Mar 29, 2025
b67f2c3
Enable battery of tests
willtebbutt Mar 29, 2025
2bdb0ad
Distributions integration testing
willtebbutt Mar 29, 2025
d4fa5c8
Enable DI CI
willtebbutt Mar 29, 2025
4902ce2
Enable reverse-mode integration tests for Lux etc
willtebbutt Mar 29, 2025
7f57a06
Enable 1.10
willtebbutt Mar 31, 2025
60e4d89
Fix LAPACK on 1.10
willtebbutt Mar 31, 2025
41bb3c3
Implement copytrito for 1.10
willtebbutt Mar 31, 2025
2140edd
formatting
willtebbutt Mar 31, 2025
05bac94
Merge branch 'main' into gd/forward
willtebbutt Mar 31, 2025
df0cf38
Tidying up
willtebbutt Mar 31, 2025
dac008f
Remove type piracy
willtebbutt Mar 31, 2025
48b61ec
Initial forwards-mode timings
willtebbutt Mar 31, 2025
3d9f9bf
Merge in main
willtebbutt May 11, 2025
05d3c65
Constrain JuliaInterpreter
willtebbutt May 26, 2025
df0d2d7
Basic MistyClosure support
willtebbutt May 26, 2025
ed912eb
Merge in main
willtebbutt May 26, 2025
b9c5f7e
Do not use MistyClosure internals inside reverse-mode
willtebbutt Jun 4, 2025
6990348
Forwards-over-reverse mwe
willtebbutt Jun 4, 2025
941e171
Remove overly strict performance check
willtebbutt Jun 4, 2025
180b43e
Docstring and improved field naming
willtebbutt Jun 4, 2025
2bbf98d
Separate forward-mode and reverse-mode primitives
willtebbutt Jun 16, 2025
f9151ed
Fix docs and rrule creation
willtebbutt Jun 16, 2025
2d52a2c
Fix low_level_maths
willtebbutt Jun 16, 2025
dd023e7
Fix SpecialFunctions tests cases
willtebbutt Jun 16, 2025
31b1733
Fix more testing
willtebbutt Jun 16, 2025
bfa2476
Fix formatting
willtebbutt Jun 16, 2025
27cd7c1
Make symbols available in tests
willtebbutt Jun 16, 2025
8079137
Fix GP test suite
willtebbutt Jun 16, 2025
f37626f
Fix SpecialFunctions test suite
willtebbutt Jun 16, 2025
2df8d36
Merge branch 'main' into gd/forward
willtebbutt Jun 16, 2025
2275845
Fix performance
willtebbutt Jun 16, 2025
720e410
Fix array tests
willtebbutt Jun 16, 2025
0d16652
Fix formatting
willtebbutt Jun 16, 2025
3aaac95
Fix forward-mode benchmarking
willtebbutt Jun 16, 2025
75f8a76
Fix benchmarking
willtebbutt Jun 16, 2025
ca2e32c
forward mode interface
willtebbutt Jun 17, 2025
9180940
Merge in main
willtebbutt Jun 30, 2025
f5b1e0e
Add frule for eps
willtebbutt Jun 30, 2025
c918f4c
Merge in main
willtebbutt Jul 1, 2025
bb4d4f6
Merge in main
willtebbutt Jul 20, 2025
ae7d8e7
Remove redundant ignore
willtebbutt Jul 20, 2025
0413507
Rename from_chain_rule macro to from_chainrules
willtebbutt Jul 20, 2025
f7ea48d
Finish renaming chainrules macro
willtebbutt Jul 20, 2025
aff7f22
Improve docstring for value_and_derivative
willtebbutt Jul 20, 2025
ef1d4f7
DRY out global interpreter cache
willtebbutt Jul 20, 2025
92370d2
Fix typo in docstring
willtebbutt Jul 20, 2025
274b804
Doctests for is_primitive macro
willtebbutt Jul 20, 2025
b1263cf
Tidying up
willtebbutt Jul 20, 2025
fecea5f
Fix typo in bijectors
willtebbutt Jul 20, 2025
33ab403
Fix test_rule call
willtebbutt Jul 20, 2025
ef9dd72
Fix formatting
willtebbutt Jul 20, 2025
e580bbf
Fix dispatch doctor in forward mode
willtebbutt Jul 20, 2025
a0bab14
Fix import
willtebbutt Jul 20, 2025
c6c376b
Fix doctests
willtebbutt Jul 20, 2025
a46cae5
Fix BLAS tests
willtebbutt Jul 21, 2025
254eeae
Fix DispatchDoctor tests
willtebbutt Jul 21, 2025
45e322d
Fix broken tests
willtebbutt Jul 21, 2025
c602978
Include the mode in testset string
willtebbutt Jul 21, 2025
d95e389
Support try-catch statements
willtebbutt Jul 21, 2025
9c904be
Fix on LTS
willtebbutt Jul 21, 2025
042ff36
Support enter expression
willtebbutt Jul 21, 2025
7b1adbe
Formatting
willtebbutt Jul 21, 2025
a0ca9bf
Enable FunctionWrappers
willtebbutt Jul 21, 2025
e4f816a
Enable GPU CI on LTS
willtebbutt Jul 21, 2025
c804438
Enable function wrappers in CI
willtebbutt Jul 21, 2025
f45068a
Bump patch version and create HISTORY
willtebbutt Aug 3, 2025
e9f8599
Fix typo in prepare_derivative_cache docstring
willtebbutt Aug 3, 2025
136984b
Note new components of public interface
willtebbutt Aug 3, 2025
ab7dc2b
Add docstring to Dual
willtebbutt Aug 3, 2025
60fb2af
Add value_and_derivative and preparation function to exports / public…
willtebbutt Aug 3, 2025
e407c2f
Merge in main
willtebbutt Aug 3, 2025
9a5fc0f
Improve interface docstring
willtebbutt Aug 3, 2025
c76c21c
Remove comment about which we have an open issue
willtebbutt Aug 3, 2025
d9001ed
Typo
willtebbutt Aug 3, 2025
f4dc6f9
Improve documentation
willtebbutt Aug 3, 2025
a26661b
Rename const_dual to make clear that it mutates
willtebbutt Aug 3, 2025
1af8093
Clarify use of insert_node
willtebbutt Aug 3, 2025
24a088b
Improve docstring
willtebbutt Aug 3, 2025
2bd3e9c
Tie todo note to github issue
willtebbutt Aug 3, 2025
50d4a30
Apply suggestions from code review
willtebbutt Aug 3, 2025
b706703
Merge branch 'gd/forward' of https://github.com/gdalle/Mooncake.jl in…
willtebbutt Aug 3, 2025
b7d1116
Make use of inc_args for PiNode in reverse mode
willtebbutt Aug 3, 2025
acc4b9e
Nospecialise on rules and remove redundant comment
willtebbutt Aug 3, 2025
3cb238d
More avoidance of specialisation
willtebbutt Aug 3, 2025
fc6e668
Remove errant nospecialize and directly test rule caching
willtebbutt Aug 3, 2025
fa37327
Add not deepcopying behaviour in test
willtebbutt Aug 6, 2025
d57c184
Improve zero_derivative implementation
willtebbutt Aug 6, 2025
61c9c54
Tidy up zero_adjoint and add deprecated file
willtebbutt Aug 6, 2025
709ce39
Merge in main
willtebbutt Aug 6, 2025
4e00aef
Rename some files
willtebbutt Aug 6, 2025
8c0569d
Fix caching test bug
willtebbutt Aug 6, 2025
6226bb8
Fix avoiding non diff code
willtebbutt Aug 6, 2025
d125324
Tidy up from_chainrules
willtebbutt Aug 7, 2025
d2d4fad
Merge branch 'main' into gd/forward
willtebbutt Aug 7, 2025
ca08ecd
Update HISTORY
willtebbutt Aug 7, 2025
0512f17
Update history
willtebbutt Aug 7, 2025
4c3f54d
Test interface kwargs
willtebbutt Aug 7, 2025
5ec27db
Refine forward-over-reverse test
willtebbutt Aug 7, 2025
36fb2a3
Include MistyClosures in CI
willtebbutt Aug 7, 2025
b8f4798
Remove incorrect test
willtebbutt Aug 7, 2025
ec8c7eb
Ensure that rrule for MistyClosure errors loudly
willtebbutt Aug 7, 2025
ba81ae6
Formatting
willtebbutt Aug 7, 2025
3259e83
Fix flux integration tests
willtebbutt Aug 7, 2025
1e1c227
Formatting
willtebbutt Aug 8, 2025
fe5a1c7
Fix mode access
willtebbutt Aug 8, 2025
3548c25
Remove unnecessary ReverseMode import
willtebbutt Aug 8, 2025
0d6b388
Fix macro bug
willtebbutt Aug 8, 2025
fa1fe13
Formatting
willtebbutt Aug 8, 2025
96f0397
Merge branch 'main' into gd/forward
willtebbutt Aug 8, 2025
f1b4c25
Update misty closure set_to_zero implementation
willtebbutt Aug 8, 2025
7623d1d
Update Project.toml
yebai Aug 11, 2025
aa55284
Merge branch 'main' into gd/forward
yebai Aug 12, 2025
43a3fe7
Merge branch 'main' into gd/forward
yebai Aug 12, 2025
be0e9dd
typofix
yebai Aug 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jobs:
'rrules/low_level_maths',
'rrules/memory',
'rrules/misc',
'rrules/misty_closures',
'rrules/new',
'rrules/random',
'rrules/tasks',
Expand Down
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# 0.4.143

## Public Interface
- Mooncake offers forward mode AD.
Comment on lines +3 to +4
Copy link
Collaborator

@penelopeysm penelopeysm Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's amazing work, though it's not clear how one can use it? I assume that ADTypes / DifferentiationInterface support will take a bit of time to arrive, but in the meantime do I just replace value_and_gradient!! with value_and_derivative!!?

And congratulations to all involved! 🎉

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DI support landing today

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- Two new functions added to the public interface: `prepare_derivative_cache` and `value_and_derivative!!`.
- One new type added to the public interface: `Dual`.

## Internals
- `get_interpreter` was previously a zero-arg function. Is now a unary function, called with a "mode" argument: `get_interpreter(ForwardMode)`, `get_interpreter(ReverseMode)`.
- `@zero_derivative` should now be preferred to `@zero_adjoint`. `@zero_adjoint` will be removed in 0.5.
- `@from_chainrules` should now be preferred to `@from_rrule`. `@from_rrule` will be removed in 0.5.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.146"
version = "0.4.147"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -59,8 +59,9 @@ FunctionWrappers = "1.1.3"
GPUArraysCore = "0.1, 0.2"
Graphs = "1"
InteractiveUtils = "1"
JET = "0.9, 0.10"
JET = "0.9"
JuliaFormatter = "1.0, 2.1"
JuliaInterpreter = "0.9"
LinearAlgebra = "1"
LuxLib = "1"
MistyClosures = "2"
Expand All @@ -81,9 +82,10 @@ DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "Pkg", "StableRNGs", "Test"]
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "JuliaInterpreter", "Pkg", "StableRNGs", "Test"]
23 changes: 22 additions & 1 deletion bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@ using AbstractGPs,
Zygote

using Mooncake:
Dual,
CoDual,
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases,
TestUtils,
_typeof,
primal,
tangent,
zero_dual,
zero_codual

using Mooncake.TestUtils: _deepcopy

to_benchmark(__frule!!::R, dx::Vararg{Dual,N}) where {R,N} = __frule!!(dx...)

function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N}
dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
out, pb!! = __rrule!!(dx_f...)
Expand Down Expand Up @@ -206,6 +210,20 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
evals=1,
)

# Benchmark AD via Mooncake.
@info "Mooncake (Forward)"
rule = Mooncake.build_frule(args...)
duals = map(x -> x isa CoDual ? Dual(x.x, x.dx) : zero_dual(x), args)
to_benchmark(rule, duals...)
include_other_frameworks && GC.gc(true)
suite["mooncake_fwd"] = Chairmarks.benchmark(
() -> (rule, duals),
identity,
a -> to_benchmark(a[1], a[2]...),
_ -> true;
evals=1,
)

if include_other_frameworks
if should_run_benchmark(Val(:zygote), args...)
@info "Zygote"
Expand Down Expand Up @@ -258,6 +276,7 @@ function combine_results(result, tag, _range, default_range)
d = result[2]
primal_time = minimum(d["primal"]).time
mooncake_time = minimum(d["mooncake"]).time
mooncake_fwd_time = minimum(d["mooncake_fwd"]).time
zygote_time = in("zygote", keys(d)) ? minimum(d["zygote"]).time : missing
rd_time = in("rd", keys(d)) ? minimum(d["rd"]).time : missing
ez_time = in("enzyme", keys(d)) ? minimum(d["enzyme"]).time : missing
Expand All @@ -267,6 +286,8 @@ function combine_results(result, tag, _range, default_range)
primal_time=primal_time,
mooncake_time=mooncake_time,
Mooncake=mooncake_time / primal_time,
mooncake_fwd_time=mooncake_fwd_time,
MooncakeFwd=mooncake_fwd_time / primal_time,
zygote_time=zygote_time,
Zygote=zygote_time / primal_time,
rd_time=rd_time,
Expand Down Expand Up @@ -348,7 +369,7 @@ end

function create_inter_ad_benchmarks()
results = benchmark_inter_framework_rules()
tools = [:Mooncake, :Zygote, :ReverseDiff, :Enzyme]
tools = [:Mooncake, :MooncakeFwd, :Zygote, :ReverseDiff, :Enzyme]
df = DataFrame(results)[:, [:tag, :primal_time, tools...]]

# Plot graph of results.
Expand Down
74 changes: 37 additions & 37 deletions ext/MooncakeSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,44 @@ module MooncakeSpecialFunctionsExt
using SpecialFunctions, Mooncake
using Base: IEEEFloat

import Mooncake: @from_rrule, DefaultCtx, @zero_adjoint
import Mooncake: DefaultCtx, @from_chainrules, @zero_derivative

@from_rrule DefaultCtx Tuple{typeof(airyai),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airyaix),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airybi),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(airybiprime),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(besselj0),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(besselj1),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(bessely0),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(bessely1),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(dawson),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(digamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfc),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logerfc),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfcinv),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfcx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logerfcx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfi),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(erfinv),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(gamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(invdigamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(trigamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(loggamma),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expintx),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(expinti),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(sinint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat}
@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airyai),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airyaix),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airybi),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(airybiprime),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(besselj0),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(besselj1),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(bessely0),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(bessely1),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(dawson),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(digamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erf),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erfc),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(logerfc),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erfcinv),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erfcx),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(logerfcx),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erfi),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(erfinv),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(gamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(invdigamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(trigamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(loggamma),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(expint),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(expintx),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(expinti),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(sinint),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(cosint),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(ellipk),IEEEFloat}
@from_chainrules DefaultCtx Tuple{typeof(ellipe),IEEEFloat}

@zero_adjoint DefaultCtx Tuple{typeof(logfactorial),Integer}
@zero_derivative DefaultCtx Tuple{typeof(logfactorial),Integer}

end
31 changes: 24 additions & 7 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Base:
twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics,
bitcast,
Expand All @@ -29,6 +30,8 @@ using Core:
GotoIfNot,
PhiNode,
PiNode,
PhiCNode,
UpsilonNode,
SSAValue,
Argument,
OpaqueClosure,
Expand All @@ -43,6 +46,13 @@ using DispatchDoctor: @stable, @unstable
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand Down Expand Up @@ -72,10 +82,11 @@ function rrule!! end
build_primitive_rrule(sig::Type{<:Tuple})

Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you
must also ensure that `is_primitive(context_type, sig)` is `true`. The callable returned by
this must obey the rrule interface, but there are no restrictions on the type of callable
itself. For example, you might return a callable `struct`. By default, this function returns
`rrule!!` so, most of the time, you should just implement a method of `rrule!!`.
must also ensure that `is_primitive(context_type, ReverseMode, sig)` is `true`. The callable
returned by this must obey the rrule interface, but there are no restrictions on the type of
callable itself. For example, you might return a callable `struct`. By default, this
function returns `rrule!!` so, most of the time, you should just implement a method of
`rrule!!`.

# Extended Help

Expand All @@ -95,6 +106,7 @@ build_primitive_rrule(::Type{<:Tuple}) = rrule!!
@stable default_mode = "disable" default_union_limit = 2 begin
include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -110,7 +122,8 @@ include(joinpath("interpreter", "patch_for_319.jl"))
include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))
include(joinpath("interpreter", "forward_mode.jl"))
include(joinpath("interpreter", "reverse_mode.jl"))
end

include("tools_for_rules.jl")
Expand All @@ -129,6 +142,7 @@ include(joinpath("rrules", "lapack.jl"))
include(joinpath("rrules", "linear_algebra.jl"))
include(joinpath("rrules", "low_level_maths.jl"))
include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "misty_closures.jl"))
include(joinpath("rrules", "new.jl"))
include(joinpath("rrules", "random.jl"))
include(joinpath("rrules", "tasks.jl"))
Expand All @@ -146,12 +160,15 @@ include("developer_tools.jl")

# Public, not exported
include("public.jl")

end
#! format: on

@public Config, value_and_pullback!!, prepare_pullback_cache
@public Config, value_and_pullback!!, prepare_pullback_cache, value_and_derivative!!
@public prepare_derivative_cache, Dual

# Public, exported
export value_and_gradient!!, prepare_gradient_cache
export value_and_gradient!!, prepare_gradient_cache, value_and_derivative!!
export prepare_derivative_cache

end
7 changes: 7 additions & 0 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ tangent(x::CoDual) = x.dx
Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:CoDual} = x

"""
extract(x::CoDual)

Helper function. Returns the 2-tuple `x.x, x.dx`.
"""
extract(x::CoDual) = primal(x), tangent(x)

"""
zero_codual(x)

Expand Down
2 changes: 2 additions & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: make it non-trivial. See https://github.com/chalk-lab/Mooncake.jl/issues/672
DebugFRule(rule) = rule

"""
DebugPullback(pb, y, x)
Expand Down
Loading