
Flux support (#457) #485

Merged: 9 commits merged into master from flux-support on Sep 4, 2018
Conversation

@xukai92 (Member) commented Sep 2, 2018

Replace ReverseDiff with Flux (#457)

Related issue: #462

@trappmartin (Member) left a comment

The AD tests pass and the code seems clean. Someone with a deeper understanding should check it again.

@willtebbutt self-assigned this on Sep 3, 2018
@willtebbutt mentioned this pull request on Sep 3, 2018
@willtebbutt (Member) left a comment

This looks great. I've added some thoughts directly in the comments; nothing major, though.

src/transform.jl (outdated)

```diff
@@ -354,9 +354,9 @@ logpdf_with_trans(d::PDMatDistribution, x::Array{T,2}, transform::Bool) where {T
     U = cholesky(x).U
     n = dim(d)
     for i in 1:n
-        lp += (n - i + T(2)) * log(U[i, i])
+        lp += (n - i + 2.0) * log(U[i, i])
```
@willtebbutt (Member):

Is there a reason that we have changed from T(2) to simply 2.0? Might this not introduce a type instability if lp isn't a Float64? (I don't know if this is currently possible.)

@xukai92 (Member, Author):

Flux.TrackedReal doesn't have this construction.

@willtebbutt (Member):

Sure, but

```julia
x = Flux.Tracker.TrackedReal(5.0)
y = x * T(2)
```

makes y a TrackedReal. The point is that constants don't need to be Tracked, but can interact with Tracked quantities.
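As a runnable illustration of that point (an editor's sketch against the Flux 0.6-era Tracker API; `Flux.param` is assumed here as the way to create a tracked scalar):

```julia
using Flux  # Flux 0.6-era API with Tracker-based AD

x = Flux.param(5.0)  # a Tracker.TrackedReal{Float64}
y = x * 2.0          # an untracked Float64 constant interacts fine

typeof(y) <: Flux.Tracker.TrackedReal  # true: the result stays tracked
Flux.Tracker.data(y)                   # 10.0, the underlying value
```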

@xukai92 (Member, Author):

Yes. But with that signature, T is TrackedReal. What's a neat way to make this generic, something like TrackedReal{T}?

@xukai92 (Member, Author):

Maybe a callback function for TrackedReal would be OK?

@willtebbutt (Member):

Ah, yes, you're completely correct. @MikeInnes, any thoughts on best practice here?

@willtebbutt (Member) commented Sep 3, 2018:

I think it would be preferable to avoid adding any TrackedReal-specific code. If there's nothing easy to do, then the current implementation is probably fine.
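One easy generic option (an editor's sketch, not necessarily the fix the PR ended up adopting) is to drop the conversion entirely and let an integer literal promote:

```julia
using LinearAlgebra

# Standalone, hypothetical version of the loop from src/transform.jl.
function logdet_term(x::AbstractMatrix)
    U = cholesky(x).U
    n = size(x, 1)
    lp = zero(eltype(x))
    for i in 1:n
        # `2` is an untracked integer literal; Julia's promotion rules give
        # the right result type whether eltype(x) is Float64 or a tracked
        # real, with no T(2) construction required.
        lp += (n - i + 2) * log(U[i, i])
    end
    return lp
end

logdet_term([2.0 1.0; 1.0 2.0])  # ≈ 1.445 for this 2×2 example
```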

@xukai92 (Member, Author):

Is the use of a Union type stable?

@willtebbutt (Member) commented Sep 3, 2018:

In what context? (It depends.)
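For background on that answer (an editor's note, not part of the original thread): dispatch on a small Union of concrete types is usually fine, because each call still specializes on the concrete argument type; instability tends to come from bindings whose type is only known as a Union at run time. A minimal sketch:

```julia
# Each call compiles a specialization for the concrete argument type, so the
# Union in the signature does not itself introduce a type instability.
g(x::Union{Float64, Int}) = x + one(x)

g(1.0)  # Float64 method instance
g(1)    # Int method instance

# By contrast, a container with a Union element type leaves each element's
# type to be resolved at run time (small Unions are still handled well by
# the compiler's union splitting).
v = Union{Float64, Int}[1.0, 2]
sum(x -> x + one(x), v)
```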

@xukai92 (Member, Author):

I just pushed a solution to this. Please check.

src/core/ad.jl (outdated)

```julia
else
    ctape = spl.info[:reverse_diff_cache][:ctape]
    res = spl.info[:reverse_diff_cache][:res]
    f_r(ipts) = begin
```
@willtebbutt (Member):

Might be an idea to either comment this a bit more thoroughly or change the variable names to self-document a bit better. Probably both. Alternatively, just make f_r an anonymous function, as it's only used in the gradient call: that would avoid the need to choose a name for it.
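A minimal sketch of the anonymous-function variant (editor's illustration; `logjoint` is a hypothetical stand-in for the PR's cached log-joint evaluation):

```julia
using Flux: Tracker

# Hypothetical stand-in for the quantity being differentiated.
logjoint(ipts) = -sum(abs2, ipts) / 2

theta = randn(3)

# Pass the closure straight to `gradient` instead of binding it to a
# throwaway name like `f_r`.
grad = Tracker.gradient(ipts -> logjoint(ipts), theta)
```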

@xukai92 (Member, Author):

Good idea!

src/core/ad.jl (outdated)

```julia
# vi[spl] = realpart(vi[spl])
# vi.logp = 0
grad = Tracker.gradient(f_r, theta)
```
@willtebbutt (Member) commented Sep 3, 2018:

grad only appears to be used once, in the final line. Perhaps don't make a separate variable called grad, or move it to just before it's used? If there's state in grad that is being mutated on lines 121-125, and hence preventing this line from being moved, please document this.

src/core/ad.jl (outdated)

```julia
@@ -108,38 +110,20 @@ gradient2(_vi::VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
    g(vi[spl])
end

@init @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin

gradient_r(theta::Vector{Float64}, vi::VarInfo, model::Function) = gradient_r(theta, vi, model, nothing)
gradient_r(theta::Vector{Float64}, vi::Turing.VarInfo, model::Function, spl::Union{Nothing, Sampler}) = begin
```
@willtebbutt (Member):

This line and the one preceding it overrun the 92-character limit. Probably an idea to refactor.
Also, do we really need the concrete Vector{Float64}? Would AbstractVector{<:Real} not suffice?
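A sketch of what that refactor could look like (editor's illustration with stub types, so it runs standalone; the PR may have resolved it differently):

```julia
# Stubs standing in for Turing's types, just to make the sketch self-contained.
struct VarInfo end
struct Sampler end

# Relaxed, line-wrapped signature: any real-valued vector is accepted.
function gradient_r(theta::AbstractVector{<:Real}, vi::VarInfo, model::Function,
                    spl::Union{Nothing, Sampler})
    # body elided in this sketch
    return theta
end

# Three-argument convenience method forwarding to the full form.
gradient_r(theta::AbstractVector{<:Real}, vi::VarInfo, model::Function) =
    gradient_r(theta, vi, model, nothing)

gradient_r([1.0, 2.0], VarInfo(), identity)
```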

@willtebbutt (Member):

(I appreciate you've not modified this line, but since we're here...)

@willtebbutt (Member):

Also, documentation.

@willtebbutt (Member) left a comment

I think these changes are fine. I'm happy to merge, once whatever is causing the build to fail is resolved.

src/core/ad.jl (outdated)

```julia
@@ -94,7 +96,7 @@ end

# Direct call of ForwardDiff.gradient; this is slow

```

@willtebbutt (Member):

This blank line is probably unnecessary.

@yebai (Member) commented Sep 4, 2018

@xukai92 Perhaps we should consider dropping the dependency on Stan by implementing our own Hmc and Adapt types.
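A hypothetical sketch of what such internal types might look like (editor's illustration only; the names and fields are invented, not Turing's eventual design):

```julia
# Invented placeholder types for HMC configuration and step-size adaptation.
mutable struct HmcConfig
    ϵ::Float64        # leapfrog step size
    n_leapfrog::Int   # leapfrog steps per iteration
end

mutable struct StepSizeAdapt
    target_accept::Float64  # target acceptance rate for dual averaging
    n_adapts::Int           # number of warm-up iterations
end

cfg = HmcConfig(0.1, 10)
adapt = StepSizeAdapt(0.8, 1000)
```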

@xukai92 (Member, Author) commented Sep 4, 2018

Yes, I think we should improve our internal types for adaptation rather than keep supporting Stan's. This inconsistency also causes problems like #448.

@willtebbutt (Member):

> perhaps we should consider dropping the dependency on Stan

Is this something that we want to resolve in this PR, or do it separately?

@yebai (Member) commented Sep 4, 2018

> Is this something that we want to resolve in this PR, or do it separately?

It's better to do this in a separate PR!

@yebai merged commit 7cb84a5 into master on Sep 4, 2018
@yebai deleted the flux-support branch on September 4, 2018, 21:40
yebai added a commit that referenced this pull request on Sep 18, 2018