-
-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add within_gradient
#434
Add within_gradient
#434
Conversation
a532b00
to
5fd4a91
Compare
|
I may have suggested that. But aren't we (in english) differentiating the function, not the array Maybe worth trying to line up with whatever no_gradient function we need too. |
Is that argument |
I think x is needed for Tracker to work, as it will never notice a function call which doesn't get a TrackedArray. |
I was thinking of the correspondence with CRC.non_differentiable, but agree that the name gets weird when you introduce that array arg. FWIW, PyTorch calls it Edit: I just realized NNlib already has |
Hah I completely forgot that I already added that... and it's still there: Lines 90 to 91 in 0c8396e
It could just grow to |
Doesn't function myf(x, y)
if is_differentiating(myf, x, y)
# do something special
else
# normal
end
end I think we want the convention to pass everything that would get passed to |
Can you explain more what you see this version which takes a function doing? Does it call the function, or infer something about it (like whether it returns a non-diff type or a TrackedArray or something)? Or would you overload this for some functions, to have non-standard behaviour? Or would |
Presumably, no, it does not make sense to overload specific functions. If you want that behavior, then why even have |
I guess I don't see what's gained by passing it more arguments. You can check |
If I want the function to work regardless of AD system, then I need to check all (assuming I care about all of them). You could certainly write Either way, my point was to bring up two things:
|
One reason not to write the splat version is that dispatch doesn't work very well for that. You need a lot of methods to allow for the 3rd being the only TrackedArray, without ambiguity. One reason to dislike
Here In something like https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L109-L110 ... "training" probably means the layer's parameters are tracked, ideally, but not whether the data is. |
I would say |
Yes. The idea seems very clear for Tracker. But for non-tracker AD it seems a bit fragile really. What a Flux layer wants to know is whether you are training. But what it checks is whether the forward pass calls the In fact it looks like Yota is smart enough to do that:
|
I am a little confused. For me, it seems we are actually differentiate wrt |
Paging @dfdx for a real answer, but I think it's assuming that the forward pass of |
Correct. In Yota, One way to hide these details from Yota (and, perhaps, from all CR-based AD systems) is to add the batch_norm(...; training) = ...
rrule(batch_norm, ...; training) = ... Since AD usually doesn't analyze the contents of the |
Thanks! The hope with this rule is to magically infer whether a given evaluation of My attempt to invent a better magic rule still seems to be defeated, I guess it traces it and remembers which branch was taken, before replacing the function with the rrule.
So what's left for Flux? |
Re branches in Yota. Yes, in the latest version tracing and rrule-transformation are two different steps, which turned to be a much, much more robust implementation than a unified single-step pass. Thus, tracer sees only the primal definition of But you don't have to trick Yota - if the consensus is to use As a general thought though, if we wish to inject some special calls into functions to account for control flow, then to me the most logical step would be to add control flow constructs themselves, e.g. like JAX's |
Even with something like |
Yes, definitely, there's no way to avoid a common library. My point is not to avoid the dependency, but to minimize the API that AD systems must support. Having conditionals as a part of the API is quite common and seems to help in this case too (hopefully not only for Yota :D). |
It seems quite tricky. I think Another option would be to demand that every AD system set some I believe that's a global namespace, so different packages could actually all set |
RE task local state, my understanding is that it doesn't propagate to child tasks and would thus could be a problem for differentiating async code. That and the risk of the AD not cleaning task-local state up correctly on errors, leading to false positives for code outside of AD.
I was under the impression that something like |
This issue of branches within AD may also deserve some though in relation to JuliaDiff/ForwardDiff.jl#480 . That's at last fixed on ForwardDiff, where it seemed to cause the most problems as e.g. you push numbers through julia> gradient(prod1, x)[1] # correct
1×5 Matrix{Float32}:
0.0 0.0 24.0 0.0 0.0
julia> Diffractor.gradient(prod2, x)[1] # wrong
1×5 Matrix{Float32}:
0.0 0.0 2.0 0.0 0.0
# Same for Tracker, Zygote, Enzyme, but no longer for ForwardDiff.
julia> ChainRulesCore.@non_differentiable eltype(::Any)
julia> Yota.grad(prod2, x) # worse?
(0.0f0, (ZeroTangent(), NoTangent())) |
Symbolic condition operator existed in TensorFlow 1 (created in 2015), Theano (created in 2008) and perhaps even earlier. I think my cognitive resistance to accept equivalence of
This particular example doesn't sound lie a problem. Given a function: function foo()
if any(within_gradient, (x,x))
return f(...)
else
return g(...)
end
end Zygote should be able to analyze it's IR and replace all calls to bar(x) = any(within_gradient, (x,x))
function foo()
if bar(x)
return f(...)
else
return g(...)
end
end But for me it looks fair enough to just forbid such usage.
Thanks for the report! I've created an issue to track it. |
Thanks for the correction, I misremembered those operations as behaving more like
Currently that example will hit https://github.com/JuliaDiff/ChainRules.jl/blob/v1.44.5/src/rulesets/Base/nondiff.jl#L99, so AD never gets to generate a pullback for So a |
Shall we touch this up & merge it? It's not perfect, but nor is what we have right now. |
Let's do this. @dfdx, do you mind fleshing that |
We can add this here, but ChainRulesCore seems a better place. |
One reason not ChainRules is that we'd like this to work for ForwardDiff (FluxML/Flux.jl#2122) and Tracker (FluxML/Tracker.jl#133), and that's really the point of moving it from Flux, and neither depend on CR. Today's 744309a adds ForwardDiff via Requires. Another is that this is a bit of a dodgy mechanism, it's what Flux uses but really needs tests for each use to make sure AD hasn't out-smarted us. |
Why don't we make Tracker and ForwardDiff depend on CRC (if not CR altogether)? |
Tracker could, no big deal. ForwardDiff I think is not interested. But the more important half is that this is a hacky solution, which is no worse than Flux's present state. |
What about DiffRules.jl? I don't know what's its purpose but Zygote, ForwardDiff and Tracker depend on it. |
Seems like that's not its mission -- old stable package that does exactly one thing. And, same problem, that this is a bit of a hack which I don't think should be encouraged as the permanent canonical solution. But, absent better ideas, it's what Flux does right now, and this PR lets us fix FluxML/Tracker.jl#133 and FluxML/Flux.jl#2122 . What are the downsides? Besides not really being the permanent canonical solution, because we don't have one. |
yeah maybe I was thinking of a more permanent solution around which the ecosystem can grow instead of hacking here and there. It just seems odd to have this function here, and we also have to add Requires. Anyways it's fine, no big deal, and better than where we were. |
Yes something better would be nice. But I honestly don't see how; in the example |
IMO the current best place for a long-term solution is JuliaDiff/AbstractDifferentiation.jl#66. But we don't want to wait for that package to mature just to get Tracker.jl working for basic Flux models, so this is a sensible stopgap. |
We can make the warning scarier if we'd like to discourage use outside of Flux etc. |
Xref a similar function now added to EnzymeCore: EnzymeAD/Enzyme.jl#1851 |
Motivated by FluxML/Tracker.jl#133
Possibly it could have a better name, but what?