This repository has been archived by the owner on Apr 18, 2023. It is now read-only.
We can pretty much already support this due to the design of the function `∇` in `core.jl`, specifically the method
```julia
∇(y::Node{T}, ȳ::T) where T
```
It allows one to pass in reverse-mode sensitivities `ȳ` and doesn't restrict `y` to be scalar, meaning that we can chain together `Tape`s. This is useful because it allows us to define primitive-like functions / functors which can use Nabla to know how to differentiate themselves, and use these inside larger functions / functors in the same way that you would a primitive.
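As a concrete illustration, chaining two `Tape`s by hand might look something like the following sketch. This assumes Nabla's `Leaf`/`Tape` API and positional indexing into the reverse tape (the same convention as the example below); it is untested:

```julia
using Nabla

# Inner tape: y = tanh.(x), recorded on its own Tape.
x_ = Leaf(Tape(), randn(3))
y_ = tanh.(x_)

# Outer tape: z = sum(abs2, y), started from the *value* of y.
v_ = Leaf(Tape(), y_.val)
z_ = sum(abs2, v_)

# The reverse pass on the outer tape gives the sensitivity ȳ of z w.r.t. y,
# which is then fed into the inner tape's reverse pass via ∇(y, ȳ).
ȳ = ∇(z_)[1]      # sensitivity w.r.t. the outer leaf (positional indexing assumed)
∇x = ∇(y_, ȳ)[1]  # chained: sensitivity of z w.r.t. x
```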
If, for example, one wished to create a self-contained differentiable layer for an MLP which handles its own parameters, one could create a functor and make it a primitive. You could do something along the following lines:
```julia
mutable struct Layer{T<:AbstractMatrix, V<:Node}
    W::T
    ∇W::T
    out::V
end

function (l::Layer)(x::AbstractVector)
    W_, x_ = Leaf.(Tape(), (l.W, x))
    y_ = tanh.(W_ * x_)
    l.out = y_  # Log the result of the forward pass and the intermediate quantities required for the reverse pass.
    return y_.val
end

@explicit_intercepts Layer Tuple{∇Real}

function ∇(l::Layer, ::Type{Arg{1}}, p, y, ȳ, x)
    ∇layer = ∇(l.out, ȳ)  # Perform the reverse pass.
    l.∇W = ∇layer[1]      # Log the gradient of `W` on this pass through.
    return ∇layer[2]      # Return the reverse-mode sensitivity w.r.t. `x`.
end
```
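Once such a `Layer` primitive works, it would be used inside a larger differentiable function just like any other primitive. A hypothetical usage sketch (the constructor call and the positional indexing convention are assumptions, and this presumes the interception above works as intended):

```julia
using Nabla

# Assumes an incomplete constructor that leaves `out` unset until the first call.
l = Layer(randn(4, 3), zeros(4, 3))

# The Layer is used like a primitive inside a larger function.
f(x) = sum(l(x))

x_ = Leaf(Tape(), randn(3))
y_ = f(x_)   # forward pass; `l` records its own inner Tape
rt = ∇(y_)   # reverse pass dispatches to ∇(l, Arg{1}, ...) above
∇x = rt[1]   # sensitivity w.r.t. x (positional indexing assumed)
∇W = l.∇W    # gradient of W, logged by the Layer itself
```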
This probably won't compile (and I'm not actually sure about some of the syntax re functors), but it illustrates the general idea: a functor could be used to implement a primitive-like block of code which keeps track of its own parameters and knows how to differentiate itself, without having to hand-code the gradients. (There are a few technical issues of consistency and scoping that I have brushed over, e.g. how do you ensure that the `x` passed in to `∇` is the same `x` passed to the `Layer` functor initially? I'm going to assume that these could be resolved with some (relatively) minor changes to `core.jl`.)
This seems generally like a useful thing to be able to do, so it would probably make sense to figure out the details and partially automate the process with a macro.
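To give a flavour of what such automation might look like, a macro front-end (purely hypothetical; no such macro exists in Nabla) could generate the struct, the forward-pass logging, and the `∇` method from a single definition:

```julia
# Hypothetical syntax; `@self_differentiable` is illustrative only and
# not part of Nabla. It would expand to the Layer/∇ boilerplate above.
@self_differentiable Layer(W::AbstractMatrix) do W, x
    tanh.(W * x)
end
```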
It would be useful to support stuff like this: denizyuret/Knet.jl#144