Flux support (#457) #485
Conversation
AD tests run through and the code seems clean. Someone with a deeper understanding should check again.
This looks great. I've added some thoughts directly in the comments; nothing major though.
src/transform.jl (Outdated)

```diff
@@ -354,9 +354,9 @@ logpdf_with_trans(d::PDMatDistribution, x::Array{T,2}, transform::Bool) where {T
     U = cholesky(x).U
     n = dim(d)
     for i in 1:n
-        lp += (n - i + T(2)) * log(U[i, i])
+        lp += (n - i + 2.0) * log(U[i, i])
```
Is there a reason that we have changed from `T(2)` to simply `2.0`? Might this not introduce a type instability if `lp` isn't a `Float64`? (I don't know if this is currently possible.)
`Flux.TrackedReal` doesn't have this construction.
Sure, but `x = Flux.Tracker.TrackedReal(5.0); y = x * T(2)` makes `y` a `TrackedReal`. The point is that constants don't need to be `Tracked`, but can interact with `Tracked` quantities.
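As a side note, a minimal sketch of the promotion behaviour under discussion, assuming the old `Flux.Tracker` API (where `param` produces a tracked scalar):

```julia
using Flux.Tracker: param, istracked

x = param(5.0)   # a tracked scalar (TrackedReal)
y = x * 2.0      # the plain Float64 constant promotes; y stays tracked
istracked(y)     # true: the constant itself never needs to be wrapped
```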
Yes. But with the signature, `T` is `TrackedReal`. What's the neat way to make this generic, something like `TrackedReal{T}`?
Maybe a callback function for `TrackedReal` would be OK?
Ah, yes, you're completely correct. @MikeInnes any thoughts on best practice here?
I think it would be preferable to avoid adding any `TrackedReal`-specific code. If there's nothing easy to do, then the current implementation is probably fine.
Is use of `Union` type stable?
In what context? (it depends)
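For what it's worth, a small sketch of a case where Julia handles a `Union` well: for small unions the compiler performs union splitting, so code like the following is typically inferred without instability.

```julia
# The compiler splits Union{Nothing, Int} into concrete branches,
# so dispatch inside the function body stays type-stable.
f(x::Union{Nothing, Int}) = x === nothing ? 0 : x + 1

f(nothing)  # 0
f(41)       # 42
```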
I just pushed a solution to this. Please check.
src/core/ad.jl (Outdated)

```julia
else
    ctape = spl.info[:reverse_diff_cache][:ctape]
    res = spl.info[:reverse_diff_cache][:res]
    f_r(ipts) = begin
```
Might be an idea to either comment this a bit more thoroughly or change the variable names to self-document a bit better. Probably both. Alternatively, just make `f_r` an anonymous function, as it's only used in the `gradient` call: that would avoid the need to choose a name for it.
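A sketch of this suggestion, assuming `f_r` only wraps a log-density evaluation (the body here is an illustrative placeholder, not the actual implementation):

```julia
# Instead of naming a helper:
#   f_r(ipts) = begin ... end
#   grad = Tracker.gradient(f_r, theta)
# pass an anonymous function to the gradient call via a do-block:
grad = Tracker.gradient(theta) do ipts
    logp(ipts)  # placeholder: evaluate the log joint at ipts
end
```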
Good idea!
src/core/ad.jl (Outdated)

```julia
# vi[spl] = realpart(vi[spl])
# vi.logp = 0
grad = Tracker.gradient(f_r, theta)
```
`grad` only appears to be used once, in the final line. Perhaps don't make a separate variable called `grad`, or move it to just before it's used? If there's state in `grad` that is being mutated on lines 121-125, and hence preventing this line from being moved, please document this.
src/core/ad.jl (Outdated)

```diff
@@ -108,38 +110,20 @@ gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
     g(vi[spl])
 end

 @init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin

 gradient_r(theta::Vector{Float64}, vi::VarInfo, model::Function) = gradient_r(theta, vi, model, nothing)
 gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
```
This line and the one preceding it overrun the 92-character limit. Probably an idea to refactor.
Also, do we really need the concrete `Vector{Float64}`? Would `AbstractVector{<:Real}` not suffice?
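A sketch of what the relaxed, line-length-friendly signature could look like (hypothetical: the argument layout and the default value for `spl` are assumptions, not the actual implementation):

```julia
function gradient_r(
    theta::AbstractVector{<:Real},
    vi::Turing.VarInfo,
    model::Function,
    spl::Union{Nothing, Sampler} = nothing,
)
    # ... body unchanged ...
end
```

With the default argument in place, the separate one-argument convenience method would no longer be needed.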
(I appreciate you've not modified this line, but since we're here...)
Also, documentation.
I think these changes are fine. I'm happy to merge, once whatever is causing the build to fail is resolved.
src/core/ad.jl (Outdated)

```diff
@@ -94,7 +96,7 @@ end

 # Direct call of ForwardDiff.gradient; this is slow

```
This blank line is probably unnecessary.
@xukai92 perhaps we should consider dropping the dependency on Stan by implementing our own adaptation.
Yes, I think we should improve our internal types for adaptation rather than keep supporting Stan's. This inconsistency also causes problems like #448.
Is this something that we want to resolve in this PR, or do it separately?
It's better to do this in a separate PR!
Replace ReverseDiff with Flux (#457)
Related issue: #462