Added Bilinear layer #1009
Conversation
Seems like a reason to want FluxML/NNlib.jl#100. But even without that, I think it's not hard to avoid the splats and go almost 100x faster. Some scribbles here: https://gist.github.com/mcabbott/29cc74f287a95724d6f561f4ed285624
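For context, a toy illustration of the splat cost being discussed (my own sketch, not the contents of the gist): splatting per-sample pieces into `hcat` builds O(batch) intermediate arrays, while one big multiply does the same work in a single GEMM.

```julia
using BenchmarkTools

W  = randn(100, 100)
xs = [randn(100) for _ in 1:64]

slow(W, xs) = hcat([W * x for x in xs]...)   # splat: many small allocations
fast(W, xs) = W * reduce(hcat, xs)           # one concatenation, one GEMM

@btime slow($W, $xs)
@btime fast($W, $xs)
```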
Cool stuff, didn't know about
I updated the gist, now OMEinsum's. However I'm not sure that Flux wants to depend on that package, so I'm not so sure what the best answer is.
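A hedged sketch of what the OMEinsum formulation might look like (my guess at the index convention, matching the implementation further down):

```julia
using OMEinsum

W = randn(3, 5, 6)   # (out, in1, in2)
x = randn(5, 7)
y = randn(6, 7)

# contract both inputs against W in one einsum call
Z = ein"oij,is,js->os"(W, x, y)
size(Z)  # (3, 7)
```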
If we can't add that dependency we can just fall back to your previous implementation, thanks for the commit @mcabbott!
@mcabbott where does
Oh, is it? Edit: It's
Yeah, I tend to stick to the LTS versions, I'll check it
The only error remaining is about. I went ahead and removed its type annotations, but it'd be better to put other, suitable ones in place.
eachcol I think is Julia v1.1+ only, so it will fail on earlier versions
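For reference, Base.eachcol did arrive in Julia 1.1; a hedged compat sketch for code that must also run on 1.0 (the helper name is hypothetical):

```julia
# generator of column views, standing in for Base.eachcol on Julia 1.0
_eachcol(A::AbstractMatrix) = (view(A, :, j) for j in axes(A, 2))

for col in _eachcol(rand(3, 4))
    println(sum(col))
end
```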
I got a new implementation working now though, using
Actually, I'm having trouble. I think you might need to add a multi-arg
or something. I see
@dhairyagandhi96 can this one be merged?
@bhvieira Any reason to not add the
@arnavs I think you deleted a comment or something? Didn't see your suggestion for some reason. I might look into it, but could do it in another PR as well
Yeah, I'd made an earlier comment just asking if this was ready to merge, and then a follow-up with the bug report. Thanks for looking into it. Basically we just need chains to act on two arguments, otherwise you can't use a bilinear layer as the first in a chain. So those two lines work for me (see the sketch below), but perhaps there are better ways.
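A hypothetical version of those two lines, assuming Flux's internal applychain helper of that era (not necessarily the exact suggestion):

```julia
using Flux
import Flux: applychain

# let the first layer of a Chain consume two inputs, then fold as usual
applychain(fs::Tuple, x, y) = applychain(Base.tail(fs), first(fs)(x, y))
(c::Flux.Chain)(x, y) = applychain(c.layers, x, y)
```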
I fixed that issue @arnavs without touching
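Presumably something along these lines, letting the layer accept a single tuple so Chain's one-input plumbing still works (a sketch assuming this PR's Bilinear struct is in scope, not necessarily the committed code):

```julia
# a 2-tuple travels through Chain as one "input", so Bilinear can sit first
(a::Bilinear)(xy::Tuple{AbstractMatrix, AbstractMatrix}) = a(xy[1], xy[2])

# e.g. Chain(Bilinear(...), Dense(...))((x, y))
```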
Note that
@mcabbott Gosh this never stops haha. It's cool that we can rely on
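For readers following along, NNlib's batched_mul multiplies matching slices along the third (batch) dimension; this is what replaces the per-sample loop in the method below:

```julia
using NNlib

A = randn(2, 3, 5)
B = randn(3, 4, 5)
C = batched_mul(A, B)                  # size (2, 4, 5)
C[:, :, 1] ≈ A[:, :, 1] * B[:, :, 1]   # true: slice-wise matmul
```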
Sorry, no insult intended, if that came off wrong; I'm guilty of earlier hacks. But this is a common operation which, like

```julia
function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
    W, b, σ = a.W, a.b, a.σ
    d_z, d_x, d_y = size(W)
    d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
    size(x,2) == size(y,2) || throw(DimensionMismatch("data inputs must agree on number of columns"))

    # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
    Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))

    # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
    Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
    Z = reshape(Wyx, (d_z, :))

    # @einsum out[o,s] := σ(Z[o,s] + b[o])
    σ.(Z .+ b)
end
```
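A quick shape check of the method above (constructing directly from the fields W, b, σ exactly as the code reads them; the real constructor may differ):

```julia
W = randn(Float32, 3, 5, 6)   # (out, in1, in2), matching size(W) above
b = zeros(Float32, 3)
layer = Bilinear(W, b, tanh)  # assumes the struct's field order is (W, b, σ)

x = randn(Float32, 5, 7)      # 7 samples of the first input
y = randn(Float32, 6, 7)      # 7 samples of the second input
size(layer(x, y))             # (3, 7)
```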
With the timely PRs by @mcabbott, I think we are set here and the functionality is better than ever. Is there anything else you think we should do here @dhairyagandhi96?
Btw, should it be exported? Similarly "uncommon" functionalities aren't exported, so I did not include it, but I can add it if you deem it useful.
looks good! I would leave it unexported
Would the
could be. I didn't even know it existed though. I'll just remove the test
I really hope this goes green, this commit suggestion thing is becoming painful 😅
victory! bors r+
@CarloLucibello thanks for the efforts haha. I had no idea a simple equality test between gpu and cpu would take so much. Are gpu gradients stored as gpu arrays? Perhaps if we'd moved them back to the cpu it would've worked.
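On that question: yes, the parameters and outputs of a gpu-moved model are GPU arrays, so a hedged sketch of the move-back-first comparison (API names from CUDA-era Flux; the failing test may have differed):

```julia
using Flux   # plus CUDA (or CuArrays on older Flux) for the gpu/cpu moves

m_cpu = Dense(3, 2)
m_gpu = m_cpu |> gpu
x = randn(Float32, 3, 4)

# bring the GPU result back to the CPU before comparing
isapprox(m_cpu(x), cpu(m_gpu(gpu(x))); atol = 1f-6)
```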
bors r+
bors r-
bors r+
@DhairyaLGandhi maybe you should just merge manually here
bors r+
1009: Added Bilinear layer r=CarloLucibello a=bhvieira

A basic implementation inspired by https://pytorch.org/docs/stable/nn.html#bilinear. I haven't exported it, because I think this layer is a bit more esoteric compared with others. It basically computes interactions between two sets of inputs. I thought about augmenting it to also include the non-interaction terms (this can easily be done, e.g. augmenting the data with a row of ones), but for now it simply mirrors PyTorch's. I had to use splatting `vcat(x...)` and `hcat(x...)` in the forward pass. I wanted to avoid it, but with `reduce` I couldn't get gradients. But I think this can be improved.

Co-authored-by: Bruno Hebling Vieira <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
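The description mentions recovering the non-interaction (linear) terms by augmenting the data with a row of ones; a minimal sketch of that idea:

```julia
x = randn(Float32, 5, 7)
y = randn(Float32, 6, 7)

# append a constant 1 to every sample of each input
x1 = vcat(x, ones(Float32, 1, size(x, 2)))  # (in1 + 1, batch)
y1 = vcat(y, ones(Float32, 1, size(y, 2)))  # (in2 + 1, batch)

# with W of size (out, in1 + 1, in2 + 1), the slices hitting the ones act
# as ordinary linear terms and a bias, on top of the bilinear interactions
```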
This PR was included in a batch that successfully built, but then failed to merge into master (it was a non-fast-forward update). It will be automatically retried.