-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Orthogonal initialization feature. #1496
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know the details of this initialization algorithm, so I have no idea whether this is coded correctly.
I made some format and style suggestions to it, hope you don't mind. Feel free to ignore it if you think it doesn't make sense.
Edit: tests cases are required.
if rows < cols | ||
Q = transpose(Q) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another small one-liner trick and feel free to take any of it, or just ignore it.
if rows < cols | |
Q = transpose(Q) | |
end | |
Q = rows < cols ? transpose(Q) : Q |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I should keep my thing, looks more elegant 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the reason this strikes several of us as weird is partly that it's not type-stable to re-use Q, not just for different things, but for different types depending on the values of rows, cols
. This isn't performance-critical code but that's where everyone's taste was honed.
Again, I would write
return rows > cols ? gain .* M : gain .* transpose(M)
where M
is some name for the thing which isn't Q
anymore, and the two branches match the branches which generate the random numbers above. They could both be written out on several lines, mat = if rows > cos; randn(...
etc, but however they are written, I think they should put the then/else clauses in the same order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah the mirrored if-else clause are a bit confusing. Should change that if nothing else.
src/utils.jl
Outdated
end | ||
|
||
rows = dims[1] | ||
cols = mapreduce(x->x, *, dims; init=1) ÷ rows |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just div(prod(dims), rows)
right?
But why does reshaping this orthogonal matrix make sense? When would that be desirable -- instead of (say) orthogonal per slice, or just an error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Orthogonal per slice could be done. That is a nice idea indeed.
In the PyTorch implementation, they flatten the matrix if the dimensions are greater than 2. I thought of implementing a similar thing.
What do you think would be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I met this idea 5 minutes ago, and haven't read the paper. It seems plausible that orthogonal matrices might be somehow interesting in Dense layers. That only needs ndims=2, and the simplest option would be just to restrict to that. (Would also guarantee that later changes to how you treat 4 dims won't break anything.)
For a Conv layer, my guess would be that orthogonal in the channel dimensions (i.e. the last 2 dims of W). But I've no idea if that's something anyone does. Or possibly they do other things? Could leave until someone needs it.
I wonder why they reshape things in Pytorch, I would have thought the result didn't have any special properties.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you on the Conv layer point. I believe it should be orthogonal channel-wise. Let me open a new issue on it at the PyTorch repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. For now, perhaps this should be restricted to matrices? Justorthogonal_rang(rng, rows, cols; gain)
. This method is probably what you would want to call repeatedly to make higher-dim ones, in any case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rely on the PyTorch implementation since a large number of people use it. We can always change it later as we feel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not blindly copy them. But if it turns out they have a reason, that may be fine.
Added examples rather than plain code. Co-authored-by: Michael Abbott <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
rename to |
@CarloLucibello Since the other initialization techniques do not have any init keyword in their function name, I think it would be better to keep it |
Could someone help me find where the docstring is going wrong? I am consistently getting the error |
I also thing the name should be more than |
Since it behaves a lot like |
@darsnack @DhairyaLGandhi I have made the required changes. All builds passing, only doctest failing which also seems to be from some other set of tests. |
@SomTambe No problem, thanks for putting the effort in to see this through. The only remaining changes are the docstring changes. Do you mind making those? |
Let me try figuring out what has gone wrong. If you get time, can you mark out in brief what needs to be changed? If you are occupied (which I expect), the general drill will be that I will take some time (I guess till tomorrow morning, since it is night here in India) and I will make the changes 😃 . |
@darsnack Let me try figuring out what has gone wrong. If you get time, can you mark out in brief what needs to be changed? If you are occupied (which I expect), the general drill will be that I will take some time (I guess till tomorrow morning, since it is night here in India) and I will make the changes 😃 . |
They should be marked out in my last review above ☝🏾 |
Don't worry about the doctest errors btw. The issue is addressed in another PR, we'll either just need to rebase, or we'll merge this PR with the failing test (knowing that it is already fixed in another PR). |
Sorry, one more thing! You need to add |
Add the See Also section. Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
A very final attempt in favor of the |
Co-authored-by: Michael Abbott <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
@darsnack Made the required changes. I think we should go ahead with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still want to switch the docstrings to cross-reference the initialization sections of the docs instead of all-to-all with the other initialization functions.
But this PR is good to go, and we can make that change in a separate PR. Final review from @DhairyaLGandhi needed.
@DhairyaLGandhi Could you review the request? Maybe merge it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! This looks great!
bors r+
Build succeeded: |
As per issue JuliaLang/julia#1431 I have added the Orthogonal matrix initialization feature.
I will add the tests gradually. Just wondering what they can be.
PR Checklist
@dhairyagandhi96
(for API changes).