Forwards-Mode Design Docs#386

Merged
willtebbutt merged 15 commits into main from wct/forwards-mode-design-docs
Nov 27, 2024

@willtebbutt
Collaborator

This is initial work on my thinking regarding forwards-mode AD. I wanted to have finished a complete draft today, but the day got away from me. A complete draft should be finished tomorrow, at which point I'll ask for review.

@codecov

codecov bot commented Nov 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
src/debug_mode.jl 100.00% <ø> (ø)
src/interpreter/ir_normalisation.jl 89.39% <ø> (ø)
src/rrules/builtins.jl 97.93% <ø> (ø)

@github-actions
Contributor

github-actions bot commented Nov 21, 2024

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │     71.6 │     1.0 │         5.5 │ missing │
│                  _sum_1000 │     6.82 │  1360.0 │        32.8 │ missing │
│               sum_sin_1000 │     2.24 │     1.7 │        10.4 │ missing │
│              _sum_sin_1000 │     2.74 │   265.0 │        13.7 │ missing │
│                   kron_sum │     57.9 │    3.46 │       191.0 │ missing │
│              kron_view_sum │     58.9 │    9.54 │       252.0 │ missing │
│      naive_map_sin_cos_exp │     2.55 │ missing │        7.17 │ missing │
│            map_sin_cos_exp │     2.71 │    1.52 │        6.06 │ missing │
│      broadcast_sin_cos_exp │     2.59 │    3.05 │        1.46 │ missing │
│                 simple_mlp │     6.07 │    2.61 │        10.1 │ missing │
│                     gp_lml │     13.4 │     6.6 │     missing │ missing │
│ turing_broadcast_benchmark │     3.25 │ missing │        23.3 │ missing │
│         large_single_block │     4.01 │  4160.0 │        30.0 │ missing │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘

@willtebbutt willtebbutt changed the title WIP: Forwards-Mode Design Docs Forwards-Mode Design Docs Nov 22, 2024
@willtebbutt
Collaborator Author

willtebbutt commented Nov 22, 2024

@gdalle I think you could take a look at this now. Please feel free to push directly to it if there are changes that you would like to make. (build is available at the bottom of the CI run list).

@gdalle
Collaborator

gdalle commented Nov 23, 2024

Thank you for this very useful clarification! My two remaining questions are:

  • Is there a reason to use both rule_for_fun and frule!!(::typeof(fun)) in the tutorial? Do they refer to different things?
  • Is there a better terminology than compile time vs run time in this case? Because both rule generation and rule execution technically happen at run time?

@willtebbutt
Collaborator Author

> Is there a reason to use both rule_for_fun and frule!!(::typeof(fun)) in the tutorial? Do they refer to different things?

My intention here was to make rule_for_fun be some as-yet-unspecified callable which does what I say that a rule ought to do (without specifying how). i.e. I just want to use rule_for_fun to set up what the rule interface looks like.

frule!! and DerivedFRule are meant to be the concrete instantiations of things which satisfy the rule interface.
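
To make the distinction concrete, here is a minimal sketch of what "satisfying the rule interface" could look like. The names `SketchDual` and `sketch_frule` are hypothetical and are not Mooncake's API; using `nothing` as the tangent of a function is likewise an assumption made only for this illustration.

```julia
# Hypothetical stand-in for a dual number: a primal value paired with its tangent.
struct SketchDual{P,T}
    primal::P
    tangent::T
end

# A hand-written rule for `sin`, playing the role of `frule!!(::typeof(fun))`.
# It takes duals for the function and its argument, and returns a dual.
function sketch_frule(::SketchDual{typeof(sin)}, x::SketchDual{Float64,Float64})
    return SketchDual(sin(x.primal), cos(x.primal) * x.tangent)
end

# Any callable with this calling convention satisfies the interface; that is
# the role `rule_for_fun` plays in the discussion above.
rule_for_sin = sketch_frule

f = SketchDual(sin, nothing)  # `nothing` means "no tangent" in this sketch only
x = SketchDual(1.0, 1.0)
y = rule_for_sin(f, x)
```

Here `y.primal` holds `sin(1.0)` and `y.tangent` holds the directional derivative `cos(1.0)`. A derived rule would be another callable with the same calling convention, built from IR rather than written by hand.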

Do you think I should add a sentence somewhere to make this clearer?

> Is there a better terminology than compile time vs run time in this case? Because both rule generation and rule execution technically happen at run time?

Good point. Do you think just referring to "rule compilation time" and "rule run time" or something would make sense? Or perhaps "rule generation" and "rule execution", as you propose?

@gdalle
Collaborator

gdalle commented Nov 23, 2024

> Do you think I should add a sentence somewhere to make this clearer?

Yes.

> Good point. Do you think just referring to "rule compilation time" and "rule run time" or something would make sense? Or perhaps "rule generation" and "rule execution", as you propose?

Either is fine I guess.

@willtebbutt
Collaborator Author

willtebbutt commented Nov 23, 2024

Great. I've tweaked the doc to reflect the changes. I was only able to find a single example of discussing compile time when I searched the document -- possibly I've missed more examples where I discuss it?

Either way, what do you make of this? Do you think this plan seems sensible?

@gdalle
Collaborator

gdalle commented Nov 23, 2024

I think it makes sense, but I won't know what info I'm missing until I've started!

@gdalle
Collaborator

gdalle commented Nov 24, 2024

@willtebbutt could you clarify why the function argument in this section appears in a QuoteNode? It's one of the more confusing aspects of metaprogramming for me. I also wonder how to decide if we need to wrap the callee f in a Dual or if it is already one.
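
For context, the general behaviour of `QuoteNode` in Julia can be sketched as follows. This only illustrates literal versus evaluated values in an expression; whether it is the exact reason for the `QuoteNode` in the design doc is an assumption.

```julia
# A bare Symbol inside an expression is a variable reference,
# so evaluating the call looks up the binding named `pi`.
ex_ref = Expr(:call, :identity, :pi)
val_ref = eval(ex_ref)

# Wrapping the Symbol in a QuoteNode marks it as a literal value:
# evaluation returns the Symbol itself instead of looking anything up.
ex_lit = Expr(:call, :identity, QuoteNode(:pi))
val_lit = eval(ex_lit)
```

In the first case the result is the constant `pi`; in the second it is the Symbol `:pi`. A `QuoteNode` around a value in IR plays the same role: it marks the value as something to use as-is rather than something to evaluate.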

@gdalle gdalle mentioned this pull request Nov 24, 2024
11 tasks
@gdalle
Collaborator

gdalle commented Nov 24, 2024

I have a more philosophical question about this page, which can also clarify reverse mode. The Mooncake implementation replaces f with rule_f inside the IR and then steps outside of the IR to use multiple dispatch on the rule. Is it the only way to modify the IR recursively? Or could we for instance inline all of it?

@yebai
Member

yebai commented Nov 27, 2024

Hi, I will share my thoughts on the objectives of a new ForwardDiff implementation. I hope it is helpful, and I am happy to discuss it further. (This is copied from a private discussion with @willtebbutt.)

  • chunk-mode support like current ForwardDiff and multithreading support
  • remove constraints as documented here where possible
  • high performance
  • GPU friendly
  • stability and low maintenance burden
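
As a rough illustration of the first objective, chunk mode can be sketched as a dual number that carries several tangent components at once. `ChunkDual` is a hypothetical type written for this example; it is neither ForwardDiff's nor Mooncake's implementation.

```julia
# Hypothetical chunked dual: one primal value carrying N tangent components,
# so a single pass propagates N directional derivatives at once.
struct ChunkDual{N}
    primal::Float64
    tangents::NTuple{N,Float64}
end

# Product rule applied to every tangent slot.
function Base.:*(a::ChunkDual{N}, b::ChunkDual{N}) where {N}
    return ChunkDual{N}(
        a.primal * b.primal,
        ntuple(i -> a.tangents[i] * b.primal + a.primal * b.tangents[i], Val(N)),
    )
end

# Seed with the two unit directions to get the full gradient of f(x, y) = x * y.
x = ChunkDual{2}(3.0, (1.0, 0.0))
y = ChunkDual{2}(4.0, (0.0, 1.0))
z = x * y
```

With both unit directions seeded, `z.tangents` holds the gradient `(4.0, 3.0)` after a single evaluation, rather than one evaluation per input.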

@gdalle
Collaborator

gdalle commented Nov 27, 2024

Agreed, and this is shaping up rather nicely in #389.
@willtebbutt maybe we can merge this docs PR and then revisit it as we make progress?

@willtebbutt
Collaborator Author

Sounds good. I'm going to quickly add a section comparing ForwardDiff.jl to this (mainly to make sure we've noted down some of the things that @yebai has noted above), then I'll merge.

@willtebbutt
Collaborator Author

> I have a more philosophical question about this page, which can also clarify reverse mode. The Mooncake implementation replaces f with rule_f inside the IR and then steps outside of the IR to use multiple dispatch on the rule. Is it the only way to modify the IR recursively? Or could we for instance inline all of it?

Apologies, I missed this. I'm not quite sure what you mean by "and then steps outside of the IR to use multiple dispatch on the rule". Could you elaborate a bit, @gdalle?

@gdalle
Collaborator

gdalle commented Nov 27, 2024

I mean that most of the differentiation work happens outside of the IR, by hooking into alternative high-level Julia functions (frule(f, ...) instead of f). Is it the only way to do it?
Intuitively, my expectation was that there would exist some "completely inlined" version of the IR where we would do all the changes at once?

@willtebbutt
Collaborator Author

I think this is more or less the only way to do it. One reason is that, since dynamic dispatch is a thing, you don't know what all of the code is at rule-generation time, so there's no way that you could inline everything.
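
A small sketch of the dynamic-dispatch point, using hypothetical functions `scale` and `apply_first` that are not taken from the PR:

```julia
scale(x::Float64) = 2.0 * x
scale(x::Int) = x + 1

# The element type is `Any`, so which `scale` method runs is only known once
# the value is inspected at run time.
apply_first(xs::Vector{Any}) = scale(xs[1])

r_float = apply_first(Any[1.5])  # dispatches to the Float64 method
r_int = apply_first(Any[1])      # dispatches to the Int method
```

When a rule for `apply_first` is generated, only the argument type `Vector{Any}` is known, so the IR cannot say which method of `scale` will be called. Under that assumption, the rule for the callee has to be looked up when the call actually happens instead of being inlined ahead of time.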

@willtebbutt willtebbutt merged commit 0f37c07 into main Nov 27, 2024
@willtebbutt willtebbutt deleted the wct/forwards-mode-design-docs branch November 27, 2024 12:51