Memory efficient backward #33
Merged
TimDettmers merged 70 commits into bitsandbytes-foundation:main from dbaranchuk:memory-efficient-backward on Sep 20, 2022.
Commits (70)
8ae9bb2  add memory efficient backward  (dbaranchuk)
1753aa0  refactoring  (dbaranchuk)
656de8e  minor fixes  (dbaranchuk)
876387d  minor fixes  (dbaranchuk)
ef2936a  delete CxB from state  (dbaranchuk)
4d6174b  memory efficient fp16 backward  (dbaranchuk)
b3fee1e  add dtype <-> fp16 cast  (dbaranchuk)
8d34d36  req_gradA for casted & more efficient and accurate fp16 backward  (dbaranchuk)
843ad06  Merge pull request #1 from TimDettmers/main  (dbaranchuk)
42b5fc9  add memory effcient backward option  (dbaranchuk)
ee325f0  clarified an exception message  (dbaranchuk)
d358999  refactoring  (dbaranchuk)
4dd475c  refactoring  (dbaranchuk)
e2a7576  bug fix  (dbaranchuk)
3634fc7  Merge branch 'TimDettmers:main' into memory-efficient-backward  (justheuristic)
cc4858c  some kind of warning or something when this is first executed to make…  (justheuristic)
469d5a6  test_bf16  (justheuristic)
a9c7953  cast to half before double_quant  (justheuristic)
140cdbe  check dtypes first  (justheuristic)
9379df8  check dtypes first  (justheuristic)
e29c5f5  clearer assertions  (justheuristic)
fc4a135  clearer assertions  (justheuristic)
a9fe0ff  recast to fp16  (justheuristic)
eac9aca  cast bias too  (justheuristic)
7facedd  copypaste tolerances  (justheuristic)
d9ca0ed  un-fuse bias  (justheuristic)
56a074f  un-fuse bias  (justheuristic)
e9b8711  un-fuse bias  (justheuristic)
0de1a44  change order  (justheuristic)
647c976  change order  (justheuristic)
210b9ed  debug assert  (justheuristic)
85bf529  debug assert  (justheuristic)
e2b523d  change typecast behavior  (justheuristic)
d6e25b5  change typecast behavior  (justheuristic)
1145589  change typecast behavior  (justheuristic)
1da4880  change typecast behavior  (justheuristic)
5b169f1  change typecast behavior  (justheuristic)
14048a3  safer cast  (justheuristic)
a214824  matmul -1- addmm  (justheuristic)
702cc72  debug asset  (justheuristic)
45dc198  cast properly  (justheuristic)
577275b  cast properly  (justheuristic)
e35e2c6  cast properly  (justheuristic)
cbfdf0b  cast edge case  (justheuristic)
ab9dee0  cast edge case  (justheuristic)
fa8e07c  more lenient threshold  (justheuristic)
f667032  bump threshold to 0.21  (justheuristic)
18f142e  addmm_  (justheuristic)
76ece2c  rollback  (justheuristic)
579b8c7  reduce diff  (justheuristic)
591f603  add memory efficient backward  (justheuristic)
2cd047e  run backward  (justheuristic)
7906dc4  debugpritn  (justheuristic)
4b4a9ef  debugprint  (justheuristic)
4da2227  debug  (justheuristic)
5d65817  debug  (justheuristic)
d9b8789  debug  (justheuristic)
6a826c4  pre-cast  (justheuristic)
37f805b  debug  (justheuristic)
95dafc6  cast before allclose  (justheuristic)
28a9313  cast before allclose  (justheuristic)
725cc72  cast device  (justheuristic)
e4086a2  cast device  (justheuristic)
01b4c6a  cast device  (justheuristic)
32a9a88  cast device  (justheuristic)
cff3a71  cast device  (justheuristic)
9b7d307  review  (TimDettmers)
a07825a  review  (justheuristic)
292a478  set threshold  (TimDettmers)
76ce9aa  try fp32  (justheuristic)
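Read in sequence, the commit messages outline the general approach: cast inputs to fp16 before double_quant, drop cached buffers such as CxB from the saved state, and recompute what the backward pass needs on the fly. Below is a minimal, hypothetical sketch of that pattern with a custom torch.autograd.Function; the class name, W8, and scale are illustrative placeholders, not identifiers from this PR:

```python
import torch

class MemoryEfficientLinear(torch.autograd.Function):
    """Hypothetical sketch: keep only the int8 weight and its scale in the
    autograd state and dequantize just-in-time, instead of caching an
    extra fp16 copy of the weight between forward and backward."""

    @staticmethod
    def forward(ctx, X, W8, scale):
        W = W8.to(torch.float16) * scale      # dequantize just-in-time
        out = X @ W.t()                       # fp16 matmul (assumes CUDA); W is freed after this
        ctx.save_for_backward(W8, scale)      # int8 + scale only, no fp16 weight copy
        return out

    @staticmethod
    def backward(ctx, grad_out):
        W8, scale = ctx.saved_tensors
        W = W8.to(torch.float16) * scale      # re-dequantize inside backward
        grad_X = grad_out @ W                 # d(out)/dX
        return grad_X, None, None             # int8 weight treated as frozen here

# Usage sketch (shapes illustrative):
#   X:     [batch, in]  fp16, requires_grad=True
#   W8:    [out, in]    int8
#   scale: [out, 1]     fp16
#   y = MemoryEfficientLinear.apply(X, W8, scale)
```

The trade-off is the usual one for memory-efficient backward passes: a second dequantization in backward buys a smaller saved state between the two passes.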
Conversations
Comment:
Curiously, when I tried to replace the next line from
    output += torch.matmul(subA, state.subB)
to
    output.addmm_(subA, state.subB)
the precision would drop and the tests would fail. I have no idea why: the dtypes of output, subA, and subB are always equal (tested).
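For reference, the two variants side by side, including a dtype check matching the "tested" claim above. This is a hedged paraphrase using the quoted names (state.subB written as subB), not the PR's actual code:

```python
import torch

def add_outlier_matmul(output, subA, subB):
    # Original line: out-of-place matmul, then in-place add.
    assert output.dtype == subA.dtype == subB.dtype  # "always equal (tested)"
    output += torch.matmul(subA, subB)
    return output

def add_outlier_addmm(output, subA, subB):
    # Attempted replacement: a single fused kernel computing
    # output = 1*output + 1*(subA @ subB), i.e. the same update.
    # Note addmm_ requires 2-D operands.
    assert output.dtype == subA.dtype == subB.dtype
    output.addmm_(subA, subB)
    return output
```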
Reply:
I cannot remember if I stumbled upon the same thing. I remember trying to make this matrix multiplication more efficient, but failed. What increase in error do you see?
It does not make much sense to me, since in cuBLAS you compute (A @ B) + D = C and the result of A @ B is accumulated in fp32, so the whole operation should be more precise. The same goes for fused multiply-add in general, which is more precise than a multiplication followed by an addition. It might be some weird tensor-core issue, but it makes no sense to me.
If the error is only smaller some of the time and has more variance, it would still be okay to have this. I believe it would be a good chunk faster.
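To put a number on the error increase asked about here, one could compare both paths against an fp64 reference. A hedged measurement sketch, not code from the PR; sizes, dtype, and trial count are arbitrary assumptions:

```python
import torch

@torch.no_grad()
def compare_fma_error(n=256, k=512, m=256, device="cuda",
                      dtype=torch.float16, trials=10):
    """Max abs error of (matmul, then add) vs addmm_ against fp64."""
    errs_matmul, errs_addmm = [], []
    for _ in range(trials):
        A = torch.randn(n, k, device=device, dtype=dtype)
        B = torch.randn(k, m, device=device, dtype=dtype)
        D = torch.randn(n, m, device=device, dtype=dtype)
        ref = D.double() + A.double() @ B.double()   # fp64 reference

        out1 = D.clone()
        out1 += torch.matmul(A, B)                   # multiply, then add
        out2 = D.clone()
        out2.addmm_(A, B)                            # fused path

        errs_matmul.append((out1.double() - ref).abs().max().item())
        errs_addmm.append((out2.double() - ref).abs().max().item())
    return errs_matmul, errs_addmm

# e1, e2 = compare_fma_error()
# print(max(e1), max(e2))   # compare worst-case error of the two paths
```

If the fused path really were less precise only some of the time, running several trials like this would show it as higher variance rather than a uniformly larger error.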