Demo of compile-time rule generation #593
|
This is very cool @MasonProtter . Before reviewing, I just want to make sure that I've correctly understood what you've done in this PR. Is the below roughly correct? "This PR introduces a generated function which returns the |
|
Yes @willtebbutt, that's correct. However, there's a catch with this approach that gives it some downsides: it seems like an opaque closure created in a generated function cannot be inlined, which means that you hit allocations when actually calling the rule. I have a branch where I did this for the forward diff stuff, and the effect is easier to see there:
julia> function foo(f, x)
           rule = Mooncake.generated_build_frule(Tuple{typeof(f), Float64})
           rule(Dual(f, Mooncake.NoTangent()), Dual(x, 1.0))
       end;
julia> @btime foo(x -> sin(x + 1), 1.0)
  39.581 ns (3 allocations: 96 bytes)
Dual{Float64, Float64}(0.9092974268256817, -0.4161468365471424)
whereas if we compute the rule 'normally', the compiler is allowed to call it without overhead:
julia> function bar(rule, f, x)
           rule(Dual(f, Mooncake.NoTangent()), Dual(x, 1.0))
       end;
julia> let f = x -> sin(x + 1)
           rule = Mooncake.build_frule(Mooncake.get_interpreter(), Tuple{typeof(f), Float64})
           @btime bar($rule, $f, 1.0)
       end
  10.881 ns (0 allocations: 0 bytes)
Dual{Float64, Float64}(0.9092974268256817, -0.4161468365471424)
I'll investigate at some point whether this can be overcome. |
|
Great, glad I've understood. Regarding the performance: it seems odd that this introduces allocations. Typically all of the operations associated to passing |
|
Passing around opaque closures is cheap when the compiler is allowed to inline them, but if it's not, then they're expensive. Here's an example of that:
julia> function make_oc()
           Base.Experimental.@opaque (x::Int) -> sin(x - 1)
       end;
julia> @generated function make_oc_gen()
           Base.Experimental.@opaque (x::Int) -> sin(x - 1)
       end;
julia> @btime f(x) setup=begin
           x = 1
           f = make_oc()
       end
  3.566 ns (0 allocations: 0 bytes)
0.0
julia> @btime f(x) setup=begin
           x = 1
           f = make_oc_gen()
       end
  38.581 ns (2 allocations: 64 bytes)
0.0 |
willtebbutt
left a comment
Some thoughts on design -- will review for style etc once we're all happy with this. Thanks again for opening this PR, I'm excited about it.
# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
if world > interp.world
Before merging, we should probably tweak the way that Mooncake handles world ages generally. The current approach is something of a hack (notice, for example, that DynamicDerivedRules don't accept a MooncakeInterpreter, and therefore don't have a pinned world age -- they just always grab a new interpreter on-the-fly, and use it to derive a rule). We can worry about that a bit later though.
Yeah, that'd be a good idea. I think it'd also potentially be a good idea to keep around a dictionary of all the different interpreters from different worlds, rather than just the latest interpreter. Not sure how you feel about that though
Good idea -- what do you have in mind? Make get_interpreter accept a world age + make the constant containing the interpreter a dictionary?
yeah exactly.
This would be great. My thinking is that we should wind up removing the above check entirely, and just rely on correctly propagating the world age through to everything via the interpreter.
Is this something you would be interested in doing as part of this PR, or would you rather it were done as part of a separate one?
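A rough sketch of the dictionary idea discussed in this thread, assuming one cached interpreter per world age. Everything here is hypothetical and not Mooncake's API: `SketchInterpreter` stands in for `MooncakeInterpreter`, and `get_interpreter_for_world`, `INTERPRETER_CACHE`, and `INTERPRETER_LOCK` are made-up names.

```julia
# Hypothetical stand-in for MooncakeInterpreter; only the pinned world age matters here.
struct SketchInterpreter
    world::UInt
end

# One interpreter per world age, instead of a single "latest" interpreter.
const INTERPRETER_CACHE = Dict{UInt,SketchInterpreter}()
const INTERPRETER_LOCK = ReentrantLock()

# Look up the interpreter for `world`, constructing and caching it on first use.
# The lock guards the Dict against concurrent insertion from multiple threads.
function get_interpreter_for_world(world::UInt=Base.get_world_counter())
    lock(INTERPRETER_LOCK) do
        get!(() -> SketchInterpreter(world), INTERPRETER_CACHE, world)
    end
end
```

With something along these lines, a rule builder could ask for the interpreter matching the world it was given, rather than comparing against the current world counter.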
interp = MooncakeInterpreter(DefaultCtx; world)
rule = build_rrule(interp, sig; world)

ci = expr_to_codeinfo(@__MODULE__(), [Symbol("#self#"), :sig], [], (), :(return $rule))
I think it's probably true that just returning this is unsafe because the DerivedRule contains various bits of state which get read from / written to in specific orders during the forwards- and reverse passes. The caching functionality in build_rrule therefore makes a call to _copy, which ensures that each time we ask for a rule, it gets fresh state which is definitely not shared with any other rule for this signature. Ensuring that no two instances of a rule share the same state ensures that we avoid race conditions if we're running the same rule on multiple threads, and makes recursion work correctly (the way that we handle recursion is a bit weird).
Anyway, we'll have to do something similar here. You could simply replace the :(return $rule) with :(return _copy($rule)). This will obviously have worse performance (not allocation-free), but is probably going to be better than what we have at the minute, because it will still avoid the type-unstable code here.
In order to do better, we would need to avoid _copy altogether. I suspect that we'll have to do some clever caching here to make this work well. I suspect that it can be done, but maybe it's something to do in a follow-up PR?
Ah I see, makes sense. I didn't know there was state inside the rule as well, that's unfortunate.
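To make the state-sharing concern in this thread concrete, here is a toy sketch. It is not how `DerivedRule` or `_copy` are implemented; `SketchRule`, `sketch_forward!`, `sketch_reverse!`, and `fresh_rule` are hypothetical names, and the "rule" differentiates `x -> x^2` by saving `x` on the forwards pass and reading it back on the reverse pass.

```julia
# Toy rule for x -> x^2 that keeps mutable state between its two passes.
struct SketchRule
    saved_x::Base.RefValue{Float64}
end
SketchRule() = SketchRule(Ref(0.0))

# Forwards pass: remember x, return the primal value.
sketch_forward!(r::SketchRule, x::Float64) = (r.saved_x[] = x; x^2)

# Reverse pass: read the remembered x to compute the gradient 2x * dy.
sketch_reverse!(r::SketchRule, dy::Float64) = 2 * r.saved_x[] * dy

# Stand-in for the role _copy plays: hand out a rule with fresh, unshared state.
fresh_rule(::SketchRule) = SketchRule()

# Sharing one rule across two interleaved uses clobbers the saved state.
shared = SketchRule()
sketch_forward!(shared, 2.0)
sketch_forward!(shared, 3.0)
bad = sketch_reverse!(shared, 1.0)   # 6.0, not the 4.0 expected for x = 2

# Giving each use its own copy keeps the passes independent.
a, b = fresh_rule(shared), fresh_rule(shared)
sketch_forward!(a, 2.0)
sketch_forward!(b, 3.0)
good_a = sketch_reverse!(a, 1.0)     # 4.0
good_b = sketch_reverse!(b, 1.0)     # 6.0
```

The same interleaving arises when one rule runs on several threads or is re-entered through recursion, which is why returning a single shared instance from the generated function would be unsafe.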
|
You likely need to add |
|
Thanks for letting us know! I would certainly not have thought about this. |
|
So we'll likely need a different approach, then, rather than disabling precompilation (unless we stick the compile-time-generated version in a separate module that has precompilation disabled? But that's a pretty unsatisfying solution and probably quite brittle). An alternative route we could take here would be to wait for v1.12, and use the mechanism introduced in JuliaLang/julia#56660 to try and do this. Of course, if JuliaLang/julia#56650 were to be resurrected, that'd be much better for Mooncake. It seems like @Keno is still somewhat open to that design and having it complement 56660. Maybe we can discuss with him sometime. |
|
Closed in favour of #900 |
This is a little demo follow-up to #498.
I made a generated function Mooncake.generated_build_rrule that is able to create an rrule at compile time and return it as a compile-time constant. Here's a little demo of it in action:
This thing only has the overhead of passing around a struct:
And it carries backedges to f, so that if a relevant method instance is invalidated, then the rrule is re-compiled:
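For readers unfamiliar with the mechanism, the following is a generic sketch of a generated function that computes a value once per argument type and returns it as a constant. It is not Mooncake code and not the PR's demo; `sketch_sizeof` is a made-up name chosen for illustration.

```julia
# The body of a generated function runs at code-generation time, once per
# concrete signature, and its result is spliced into the compiled method.
@generated function sketch_sizeof(::Type{T}) where {T}
    nbytes = sizeof(T)   # evaluated during generation, not at each call
    return :($nbytes)    # the generated method body is just this constant
end

sketch_sizeof(Float64)   # 8
```

The PR applies the same idea with a derived rule in place of the integer, which is what allows the rule to reach callers as a constant.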