Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add within_gradient #434

Merged
merged 4 commits into from
Jan 5, 2023
Merged

Add within_gradient #434

merged 4 commits into from
Jan 5, 2023

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 30, 2022

Motivated by FluxML/Tracker.jl#133

Possibly it could have a better name, but what?

@mcabbott mcabbott force-pushed the istraining branch 2 times, most recently from a532b00 to 5fd4a91 Compare August 30, 2022 15:54
@ToucheSir
Copy link
Member

is_differentiating was suggested at some point, IIRC?

@mcabbott
Copy link
Member Author

I may have suggested that. But aren't we (in english) differentiating the function, not the array x?

Maybe worth trying to line up with whatever no_gradient function we need too.

@chengchingwen
Copy link
Member

Is that argument x needed? It seems we can just define within_gradient() = false since we didn't (or can't?) detect whether the function is differentiated wrt x.

@mcabbott
Copy link
Member Author

I think x is needed for Tracker to work, as it will never notice a function call which doesn't get a TrackedArray.

@ToucheSir
Copy link
Member

ToucheSir commented Aug 30, 2022

I may have suggested that. But aren't we (in english) differentiating the function, not the array x?

Maybe worth trying to line up with whatever no_gradient function we need too.

I was thinking of the correspondence with CRC.non_differentiable, but agree that the name gets weird when you introduce that array arg.

FWIW, PyTorch calls it GradMode::is_enabled() (ex).

Edit: I just realized NNlib already has within_gradient() after reading JuliaDiff/ChainRulesCore.jl#547. To make the purpose of one-arg method clearer, x could be rebranded as test_val or some such and specifically documented.

@mcabbott
Copy link
Member Author

Hah I completely forgot that I already added that... and it's still there:

NNlib.jl/src/softmax.jl

Lines 90 to 91 in 0c8396e

within_grad() = false
rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)

It could just grow to within_grad(x...) or something? Maybe Tracker would like the use of it in ∇softmax to check the array.

@darsnack
Copy link
Member

Doesn't is_differentiating make sense from the user's perspective? Like

function myf(x, y)
    if is_differentiating(myf, x, y)
        # do something special
    else
        # normal
    end
end

I think we want the convention to pass everything that would get passed to rrule. e.g. x could be a TrackedArray but not y. And in english, "are we differentiating myf w.r.t. x or y?"

@mcabbott
Copy link
Member Author

mcabbott commented Aug 30, 2022

Can you explain more what you see this version which takes a function doing? Does it call the function, or infer something about it (like whether it returns a non-diff type or a TrackedArray or something)? Or would you overload this for some functions, to have non-standard behaviour? Or would is_differentiating(myf, x, y) == is_differentiating(identity, x, y) always?

@darsnack
Copy link
Member

darsnack commented Aug 30, 2022

Presumably, no, it does not make sense to overload specific functions. If you want that behavior, then why even have is_X as a check at all in the definition of f? But I was thinking from the Tracker perspective, you want to test if any of the inputs are tracked. I only included the function for consistency.

@mcabbott
Copy link
Member Author

I guess I don't see what's gained by passing it more arguments. You can check any(within_gradient, (x, y, z)), or it could be spelled within_grad(x...). You may only care about whether a particular argument is tracked, or know that it's sufficient to test one.

@darsnack
Copy link
Member

If I want the function to work regardless of AD system, then I need to check all (assuming I care about all of them). You could certainly write any(within_gradient, (x, y, z)).

Either way, my point was to bring up two things:

  1. is_deriving(x)/is_differentiating(x) makes more sense in english to me, because you care about whether you are currently taking a derivative with respect to the arguments passed. within_gradient seems more appropriate for something that never takes any arguments.
  2. Given (1), I would then calling any(is_deriving, (x, y, z)) seems like the most common use case for the function (versus caring about only a specific argument). For some AD systems, these are equivalent, but others not. So shortening this common case to is_deriving(x, y, z) seemed worth it. Only a small comment, no need to follow it if we don't think it is worth it.

@mcabbott
Copy link
Member Author

One reason not to write the splat version is that dispatch doesn't work very well for that. You need a lot of methods to allow for the 3rd being the only TrackedArray, without ambiguity.

One reason to dislike any is that Zygote knows it returns Bool, and hence stays out:

julia> gradient(x -> within_gradient(x) * x, 2.0)  # using this PR, works fine
(1.0,)

julia> gradient(x -> any(within_gradient, (x,x)) * x, 2.0)
(0.0,)

julia> gradient(x -> (within_gradient(x) || within_gradient(x)) * x, 2.0)  # could change
(1.0,)

Here ∇softmax_data does only care about one argument.

In something like https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L109-L110 ... "training" probably means the layer's parameters are tracked, ideally, but not whether the data is.

@chengchingwen
Copy link
Member

I would say is_deriving(x)/is_differentiating(x) is kind of weird for non-tracker AD. It sounds like you are checking whether the pullback get a NoTangent as being non-differentiable. Actually, that means this function would have different semantic for different AD.

@mcabbott
Copy link
Member Author

Yes. The idea seems very clear for Tracker. But for non-tracker AD it seems a bit fragile really.

What a Flux layer wants to know is whether you are training. But what it checks is whether the forward pass calls the rrule. These seems to be no reason a smarter AD couldn't notice that the condition in if is_training() can never get a gradient, thus it need not run the rrule.

In fact it looks like Yota is smart enough to do that:

julia> Zygote.withgradient(x -> within_gradient(x) ? x^2 : x, 2.0)
(val = 4.0, grad = (4.0,))

julia> Diffractor.gradient(x -> within_gradient(x) ? x^2 : x, 2.0)
(4.0,)

julia> Yota.grad(x -> within_gradient(x) ? x^2 : x, 2.0)
(2.0, (ZeroTangent(), 1))

@chengchingwen
Copy link
Member

In fact it looks like Yota is smart enough to do that:

julia> Yota.grad(x -> within_gradient(x) ? x^2 : x, 2.0)
(2.0, (ZeroTangent(), 1))

I am a little confused. For me, it seems we are actually differentiate wrt x so I would expect within_gradient returning true?

@mcabbott
Copy link
Member Author

Paging @dfdx for a real answer, but I think it's assuming that the forward pass of rrule should always agree with the original function. Then in an expression cond ? a : b there is never a need to trace into cond and replace functions with rrules, because their pullbacks will never get any nonzero input.

@dfdx
Copy link
Contributor

dfdx commented Aug 31, 2022

but I think it's assuming that the forward pass of rrule should always agree with the original function

Correct. In Yota, rrule(f, args...)[1] giving a different result than f(args...) is kind of undefined behavior - in the latest version it works as in your example, but in the previous it would most likely take the different branch. And since Yota currently doesn't support control flow, the other branch will never show up.

One way to hide these details from Yota (and, perhaps, from all CR-based AD systems) is to add the training/differentiating flag to rrules themselves. For example:

batch_norm(...; training) = ...
rrule(batch_norm, ...; training) = ...

Since AD usually doesn't analyze the contents of the rrule itself, it should be safe to put any branching logic into it. Though, I'm not sure it won't invalidate any of the compiler optimizations in Zygote and Diffractor.

@mcabbott
Copy link
Member Author

mcabbott commented Aug 31, 2022

Thanks!

The hope with this rule is to magically infer whether a given evaluation of BatchNorm is during training or inference, even though the code for it is deep inside some model. I don't think there's an obvious way to pass the rrule a flag.

My attempt to invent a better magic rule still seems to be defeated, I guess it traces it and remembers which branch was taken, before replacing the function with the rrule.

istracked_and_val(xs...) = (any(istracked, xs), xs...)

istracked(x) = false  # overload this in Tracker

ChainRulesCore.rrule(::typeof(istracked_and_val), xs...) = (true, xs...), Tuple
julia> function dropout(x1::AbstractArray, p::Real = 0.5)
           flag, x2 = istracked_and_val(x1)
           if flag  # then we are inside a gradient call, or x isa TrackedArray
             @. ifelse(rand()>p, x2, zero(x2))
           else
             x2
           end
       end;

julia> sum(dropout(ones(10)))
10.0

julia> grad(x -> sum(dropout(x)), ones(10))
(10.0, (ZeroTangent(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]))

So what's left for Flux? train! could just call testmode!(m, false) on the model first. Or maybe CRC should own some special token which AD has to know to look for & replace?

@dfdx
Copy link
Contributor

dfdx commented Sep 1, 2022

Re branches in Yota. Yes, in the latest version tracing and rrule-transformation are two different steps, which turned to be a much, much more robust implementation than a unified single-step pass. Thus, tracer sees only the primal definition of istracked_and_val and follows corresponding branch.

But you don't have to trick Yota - if the consensus is to use istracked(), I can add it directly to the library.


As a general thought though, if we wish to inject some special calls into functions to account for control flow, then to me the most logical step would be to add control flow constructs themselves, e.g. like JAX's cond. I guess it would add flexibility not only to the training/testing issue, but to a number of other use cases too.

@ToucheSir
Copy link
Member

Even with something like cond, presumably it would have to be hosted in some common package so all ADs and downstream libraries can use it without taking on unwanted deps? Or do you see a way for ADs to pick up on it without explicitly checking for/defining rules for it?

@dfdx
Copy link
Contributor

dfdx commented Sep 1, 2022

Yes, definitely, there's no way to avoid a common library. My point is not to avoid the dependency, but to minimize the API that AD systems must support. Having conditionals as a part of the API is quite common and seems to help in this case too (hopefully not only for Yota :D). istracked() seems to be a new concept in the ML space, so its harder for me to analyze its corner cases. (Which doesn't imply we shouldn't explore it!)

@mcabbott
Copy link
Member Author

mcabbott commented Sep 2, 2022

It seems quite tricky. I think @non_differentiable any at present stops Zygote from looking inside any(within_gradient, (x,x)) at all. But if there's a magic token it has to obey, then it has to keep looking... when does it stop, only when it hits ccall?

Another option would be to demand that every AD system set some task_local_storage, a bit like how @allowscalar works:

https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/GPUArraysCore/src/GPUArraysCore.jl#L40-L112

I believe that's a global namespace, so different packages could actually all set task_local_storage(:InsideReverseModeGrad, true) without depending on a common package.

@ToucheSir
Copy link
Member

RE task local state, my understanding is that it doesn't propagate to child tasks and would thus could be a problem for differentiating async code. That and the risk of the AD not cleaning task-local state up correctly on errors, leading to false positives for code outside of AD.

istracked() seems to be a new concept in the ML space, so its harder for me to analyze its corner cases. (Which doesn't imply we shouldn't explore it!)

I was under the impression that something like cond was newer? When digging for #434 (comment), it seems like PyTorch has had a "are we in AD" function for quite some time to help with choosing derivative routines.

@mcabbott
Copy link
Member Author

mcabbott commented Sep 2, 2022

This issue of branches within AD may also deserve some though in relation to JuliaDiff/ForwardDiff.jl#480 . That's at last fixed on ForwardDiff, where it seemed to cause the most problems as e.g. you push numbers through det rather than having a rule for it. But if you do have a branch based on values inside AD, the other systems show the same bad behaviour:

julia> gradient(prod1, x)[1]  # correct
1×5 Matrix{Float32}:
 0.0  0.0  24.0  0.0  0.0

julia> Diffractor.gradient(prod2, x)[1]  # wrong
1×5 Matrix{Float32}:
 0.0  0.0  2.0  0.0  0.0

# Same for Tracker, Zygote, Enzyme, but no longer for ForwardDiff.

julia> ChainRulesCore.@non_differentiable eltype(::Any)

julia> Yota.grad(prod2, x)  # worse? 
(0.0f0, (ZeroTangent(), NoTangent()))

@dfdx
Copy link
Contributor

dfdx commented Sep 2, 2022

I was under the impression that something like cond was newer? When digging for #434 (comment), it seems like PyTorch has had a "are we in AD" function for quite some time to help with choosing derivative routines.

Symbolic condition operator existed in TensorFlow 1 (created in 2015), Theano (created in 2008) and perhaps even earlier. I think my cognitive resistance to accept equivalence of istracked() and PyTorch's GradMode::is_enabled() comes from the difference between symbolic (Theano, TensorFlow, JAX) and overload-based (PyTorch) families of ADs. As far as can say, Zygote, Diffractor and Yota analyze code as a symbolic graph, so using tools from such systems (e.g. cond) instead of the differentiation flag looks more natural to me. But maybe that's just my background - I'm pretty curious to see how istracked() works in this context anyway.

It seems quite tricky. I think @non_differentiable any at present stops Zygote from looking inside any(within_gradient, (x,x)) at all.

This particular example doesn't sound lie a problem. Given a function:

function foo()
    if any(within_gradient, (x,x)) 
        return f(...)
    else
        return g(...)
    end
end

Zygote should be able to analyze it's IR and replace all calls to within_gradient(...) with just true (or equivalent). Of course, it's possible to nest within_gradient() into another auxiliary function, e.g.:

bar(x) = any(within_gradient, (x,x)) 

function foo()
    if bar(x)
        return f(...)
    else
        return g(...)
    end
end

But for me it looks fair enough to just forbid such usage.

julia> Yota.grad(prod2, x) # worse?
(0.0f0, (ZeroTangent(), NoTangent()))

Thanks for the report! I've created an issue to track it.

@ToucheSir
Copy link
Member

Thanks for the correction, I misremembered those operations as behaving more like ifelse than actually being able to lazily evaluate branch subgraphs.

This particular example doesn't sound lie a problem...

Currently that example will hit https://github.com/JuliaDiff/ChainRules.jl/blob/v1.44.5/src/rulesets/Base/nondiff.jl#L99, so AD never gets to generate a pullback for within_grad. Zygote could in theory do enough constant prop + folding to realize that any(within_gradient, ...) == true, but in practice I don't think anyone is interested in writing one (using IRTools, because we're still stuck with that) that matches Julia's built-in behaviour from scratch!

So a cond-like construct could be the path of least resistance. It does feel somewhat unidiomatic to write out each branch as a function instead of using a conditional, but at least for Zygote there is value in hiding said conditional where AD can't see it. My main concern would be whether adding an extra 2+ functions to compile for each branch like this negatively impacts latency, but that can be easily tested once a prototype exists.

@mcabbott mcabbott marked this pull request as draft September 4, 2022 20:23
@mcabbott
Copy link
Member Author

mcabbott commented Jan 3, 2023

Shall we touch this up & merge it? It's not perfect, but nor is what we have right now.

@ToucheSir
Copy link
Member

Let's do this.

@dfdx, do you mind fleshing that cond idea out a bit more and linking it back here once done? I wouldn't want to have it lost after this PR is done. Location of the writeup doesn't matter, but if you're looking for one I opened JuliaDiff/AbstractDifferentiation.jl#66 a while back.

@CarloLucibello
Copy link
Member

We can add this here, but ChainRulesCore seems a better place.

@mcabbott
Copy link
Member Author

mcabbott commented Jan 5, 2023

One reason not ChainRules is that we'd like this to work for ForwardDiff (FluxML/Flux.jl#2122) and Tracker (FluxML/Tracker.jl#133), and that's really the point of moving it from Flux, and neither depend on CR.

Today's 744309a adds ForwardDiff via Requires.

Another is that this is a bit of a dodgy mechanism, it's what Flux uses but really needs tests for each use to make sure AD hasn't out-smarted us.

@mcabbott mcabbott marked this pull request as ready for review January 5, 2023 11:45
@mcabbott mcabbott requested a review from ToucheSir January 5, 2023 16:39
@CarloLucibello
Copy link
Member

CarloLucibello commented Jan 5, 2023

One reason not ChainRules is that we'd like this to work for ForwardDiff (FluxML/Flux.jl#2122) and Tracker (FluxML/Tracker.jl#133), and that's really the point of moving it from Flux, and neither depend on CR.

Today's 744309a adds ForwardDiff via Requires.

Why don't we make Tracker and ForwardDiff depend on CRC (if not CR altogether)?

@mcabbott
Copy link
Member Author

mcabbott commented Jan 5, 2023

Tracker could, no big deal. ForwardDiff I think is not interested.

But the more important half is that this is a hacky solution, which is no worse than Flux's present state.

@CarloLucibello
Copy link
Member

What about DiffRules.jl? I don't know what's its purpose but Zygote, ForwardDiff and Tracker depend on it.

@mcabbott
Copy link
Member Author

mcabbott commented Jan 5, 2023

Seems like that's not its mission -- old stable package that does exactly one thing.

And, same problem, that this is a bit of a hack which I don't think should be encouraged as the permanent canonical solution. But, absent better ideas, it's what Flux does right now, and this PR lets us fix FluxML/Tracker.jl#133 and FluxML/Flux.jl#2122 .

What are the downsides? Besides not really being the permanent canonical solution, because we don't have one.

@CarloLucibello
Copy link
Member

yeah maybe I was thinking of a more permanent solution around which the ecosystem can grow instead of hacking here and there. It just seems odd to have this function here, and we also have to add Requires. Anyways it's fine, no big deal, and better than where we were.
Btw, I just discovered https://github.com/ThummeTo/ForwardDiffChainRules.jl

@mcabbott
Copy link
Member Author

mcabbott commented Jan 5, 2023

Yes something better would be nice. But I honestly don't see how; in the example any really blocks Zygote in a way that seems hard to picture a way around.

@mcabbott mcabbott merged commit 5f63dbf into FluxML:master Jan 5, 2023
@mcabbott mcabbott deleted the istraining branch January 5, 2023 19:29
@ToucheSir
Copy link
Member

ToucheSir commented Jan 5, 2023

IMO the current best place for a long-term solution is JuliaDiff/AbstractDifferentiation.jl#66. But we don't want to wait for that package to mature just to get Tracker.jl working for basic Flux models, so this is a sensible stopgap.

@mcabbott
Copy link
Member Author

mcabbott commented Jan 5, 2023

We can make the warning scarier if we'd like to discourage use outside of Flux etc.

@mcabbott
Copy link
Member Author

Xref a similar function now added to EnzymeCore: EnzymeAD/Enzyme.jl#1851

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants