Add Orthogonal initialization feature. #1496

Merged
merged 39 commits into from
Feb 11, 2021

Conversation

SomTambe
Member

@SomTambe SomTambe commented Feb 3, 2021

As per issue JuliaLang/julia#1431 I have added the Orthogonal matrix initialization feature.

I will add the tests gradually. Just wondering what they can be.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • Final review from @dhairyagandhi96 (for API changes).

Contributor

@johnnychen94 johnnychen94 left a comment

I don't know the details of this initialization algorithm, so I have no idea whether this is coded correctly.

I made some format and style suggestions; hope you don't mind. Feel free to ignore them if you think they don't make sense.

Edit: test cases are required.

Comment on lines +215 to +217
if rows < cols
Q = transpose(Q)
end
Contributor

Another small one-liner trick and feel free to take any of it, or just ignore it.

Suggested change
- if rows < cols
-     Q = transpose(Q)
- end
+ Q = rows < cols ? transpose(Q) : Q

Member Author

I think I should keep my thing, looks more elegant 😄

Member

@mcabbott mcabbott Feb 8, 2021

I think the reason this strikes several of us as weird is partly that it's not type-stable to re-use Q, not just for different things, but for different types depending on the values of rows, cols. This isn't performance-critical code but that's where everyone's taste was honed.

Again, I would write

return rows > cols ? gain .* M : gain .* transpose(M)

where M is some name for the thing which isn't Q anymore, and the two branches match the branches which generate the random numbers above. They could both be written out on several lines, mat = if rows > cols; randn(... etc, but however they are written, I think they should put the then/else clauses in the same order.

Member

Yeah, the mirrored if-else clauses are a bit confusing. Should change that if nothing else.
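The suggestion above (keep the then/else clauses in the same order, and avoid rebinding Q to values of different types) can be sketched as follows. This is a minimal illustration, not the code merged in this PR: the name `orthogonal_sketch`, the Float32 element type, and the sign normalization step are assumptions made for the example.

```julia
using LinearAlgebra, Random

# Hypothetical helper, not the PR's code: returns a rows×cols matrix whose
# columns (rows >= cols) or rows (rows < cols) are orthonormal, scaled by `gain`.
function orthogonal_sketch(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1)
    # Always factorize a "tall" matrix; both conditionals test `rows > cols`
    # with the then/else branches in the same order.
    mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows)
    Q, R = qr(mat)
    # Materialize the thin Q factor under a new name instead of reusing `Q`.
    # The sign normalization is an assumption of this sketch.
    M = Matrix(Q) * Diagonal(sign.(diag(R)))
    # Broadcasting produces a plain Matrix in both branches.
    return rows > cols ? gain .* M : gain .* transpose(M)
end
```

Calling `orthogonal_sketch(MersenneTwister(0), 3, 5)` should give a 3×5 matrix W for which `W * W'` is approximately the 3×3 identity.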

src/utils.jl Outdated
end

rows = dims[1]
cols = mapreduce(x->x, *, dims; init=1) ÷ rows
Member

This is just div(prod(dims), rows) right?

But why does reshaping this orthogonal matrix make sense? When would that be desirable -- instead of (say) orthogonal per slice, or just an error?
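As a quick check of the first point, the two expressions give the same column count. The shape `(3, 4, 5)` is an arbitrary example chosen for illustration.

```julia
dims = (3, 4, 5)   # example shape, not taken from the PR
rows = dims[1]

# Expression from the diff under review.
cols_original = mapreduce(x -> x, *, dims; init = 1) ÷ rows

# Simpler form suggested in the review comment.
cols_simple = div(prod(dims), rows)

@assert cols_original == cols_simple == 20
```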

Member Author

Orthogonal per slice could be done. That is a nice idea indeed.
In the PyTorch implementation, they flatten the matrix if the dimensions are greater than 2. I thought of implementing a similar thing.
What do you think would be better?

Member

@mcabbott mcabbott Feb 3, 2021

I met this idea 5 minutes ago, and haven't read the paper. It seems plausible that orthogonal matrices might be somehow interesting in Dense layers. That only needs ndims=2, and the simplest option would be just to restrict to that. (Would also guarantee that later changes to how you treat 4 dims won't break anything.)

For a Conv layer, my guess would be that it should be orthogonal in the channel dimensions (i.e. the last 2 dims of W). But I've no idea if that's something anyone does. Or possibly they do other things? Could leave until someone needs it.

I wonder why they reshape things in PyTorch; I would have thought the result didn't have any special properties.

Member Author

I agree with you on the Conv layer point. I believe it should be orthogonal channel-wise. Let me open a new issue on it at the PyTorch repo.

Member

OK. For now, perhaps this should be restricted to matrices? Just orthogonal_rang(rng, rows, cols; gain). This method is probably what you would want to call repeatedly to make higher-dim ones, in any case.

Member Author

Let's rely on the PyTorch implementation since a large number of people use it. We can always change it later as we feel.

Member

We should not blindly copy them. But if it turns out they have a reason, that may be fine.
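To make the flattening option in this thread concrete, here is a rough sketch of how a matrix method can back a higher-dimensional one by reshaping. The function names `orthogonal_matrix` and `orthogonal_flattened` are hypothetical, and this is only an illustration of the idea under discussion, not the implementation in this PR or in PyTorch.

```julia
using LinearAlgebra, Random

# Hypothetical stand-in for the matrix-only method discussed above.
function orthogonal_matrix(rng::AbstractRNG, rows::Integer, cols::Integer)
    # Thin Q factor of a tall random matrix: max(rows, cols) × min(rows, cols).
    M = Matrix(qr(randn(rng, Float32, max(rows, cols), min(rows, cols))).Q)
    return rows >= cols ? M : collect(transpose(M))
end

# Flattening approach: treat dims as dims[1] × prod(dims[2:end]), then reshape.
function orthogonal_flattened(rng::AbstractRNG, dims::Integer...)
    rows = dims[1]
    cols = div(prod(dims), rows)
    return reshape(orthogonal_matrix(rng, rows, cols), dims)
end
```

With this approach only the flattened `dims[1] × prod(dims[2:end])` view is orthogonal; individual slices such as `W[:, :, k]` generally are not, which is the concern raised earlier in the thread.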

@CarloLucibello
Member

rename to orthogonal_init or init_orthogonal?

@SomTambe
Member Author

SomTambe commented Feb 3, 2021

@CarloLucibello Since the other initialization techniques do not have any init keyword in their function name, I think it would be better to keep it orthogonal just to be consistent.

@SomTambe
Member Author

SomTambe commented Feb 3, 2021

Could someone help me find where the docstring is going wrong? I am consistently getting the error Base.Meta.ParseError("unexpected \",\"") here in my docstring, when TravisCI is trying to build my documentation. Thank you!

@darsnack
Member

darsnack commented Feb 3, 2021

I also think the name should be more than orthogonal. The other init functions like "glorot" or "kaiming" are unique to ML parameter initialization, but "orthogonal" is used in many contexts. sparse_init is an analogue where a more descriptive name is required.

@mcabbott
Member

mcabbott commented Feb 3, 2021

Since it behaves a lot like randn, perhaps something like rand_orthogonal or orthogonal_rand?

@SomTambe
Member Author

SomTambe commented Feb 8, 2021

@darsnack @DhairyaLGandhi I have made the required changes. All builds are passing; only the doctest is failing, and that failure seems to come from some other set of tests.
I was busy this weekend, sorry for the delay.

@darsnack
Member

darsnack commented Feb 8, 2021

@SomTambe No problem, thanks for putting the effort in to see this through. The only remaining changes are the docstring changes. Do you mind making those?

@SomTambe
Member Author

SomTambe commented Feb 8, 2021

> @SomTambe No problem, thanks for putting the effort in to see this through. The only remaining changes are the docstring changes. Do you mind making those?

@darsnack Let me try figuring out what has gone wrong. If you get time, can you briefly mark out what needs to be changed? If you are occupied (which I expect), I will take some time (probably until tomorrow morning, since it is night here in India) and make the changes 😃.

@darsnack
Member

darsnack commented Feb 8, 2021

They should be marked out in my last review above ☝🏾

@darsnack
Member

darsnack commented Feb 8, 2021

Don't worry about the doctest errors btw. The issue is addressed in another PR, we'll either just need to rebase, or we'll merge this PR with the failing test (knowing that it is already fixed in another PR).

@darsnack
Member

darsnack commented Feb 8, 2021

You'll also need to add orthogonal as a reference in all the other init functions' docstrings. See @mcabbott's suggestion here.

@darsnack
Member

darsnack commented Feb 8, 2021

Sorry, one more thing! You need to add orthogonal to docs/src/utilities.md.

SomTambe and others added 5 commits February 8, 2021 17:22
@CarloLucibello
Member

A very final attempt in favor of the orthogonal_init name; feel free to ignore. Would you ever give this PR the title "Add Orthogonal feature"? I guess not, because it would convey little information about its content. The same goes for Flux.orthogonal: someone looking at the code wouldn't guess the purpose of this function without looking at the docs.

@SomTambe
Member Author

SomTambe commented Feb 9, 2021

@darsnack Made the required changes. I think we should go ahead with orthogonal. All tests passed finally!

@SomTambe SomTambe requested a review from darsnack February 10, 2021 08:58
Member

@darsnack darsnack left a comment

We still want to switch the docstrings to cross-reference the initialization sections of the docs instead of all-to-all with the other initialization functions.

But this PR is good to go, and we can make that change in a separate PR. Final review from @DhairyaLGandhi needed.

@SomTambe
Member Author

@DhairyaLGandhi Could you review the request? Maybe merge it?

Member

@DhairyaLGandhi DhairyaLGandhi left a comment

Thanks! This looks great!

bors r+

@CarloLucibello CarloLucibello mentioned this pull request Feb 11, 2021
92 tasks
@bors
Contributor

bors bot commented Feb 11, 2021

Build succeeded:

@bors bors bot merged commit 4c53672 into FluxML:master Feb 11, 2021