16bit + Deterministic Training #123
base: master
Conversation
Hi! It's great that you're able to contribute! It looks like some of the tests here failed. You should try to use pytest on your local machine to make sure all of them pass. You should be able to install all the dev dependencies with
Second, it looks like some of your commits just make alterations to the same softplus function (adding something, then removing it again, etc.). You should probably clean up the commits so that they only include the relevant changes; this makes it a lot easier for the reviewer. If you're not familiar with changing the commit history, you should look up
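One simple way to collapse a run of work-in-progress commits into a single clean one is `git reset --soft` (interactive `git rebase -i` is the more flexible alternative). A hedged sketch in a throwaway repository; the file names and commit messages are illustrative only:

```shell
# Sketch: squash the last few commits into one, demonstrated in a temp repo.
set -e
repo=$(mktemp -d)
cd "$repo"
git init -q
git config user.email you@example.com
git config user.name you
echo a > utils.py && git add . && git commit -qm "Add softplus tweak"
echo b > utils.py && git add . && git commit -qm "Revert softplus tweak"
echo c > utils.py && git add . && git commit -qm "Update utils.py"
# Step back two commits but keep the combined changes staged,
# then fold them into the first commit with --amend.
git reset --soft HEAD~2
git commit -q --amend -m "Make log_softplus dtype-aware"
git rev-list --count HEAD   # a single clean commit remains
```

After this, a force push (`git push --force-with-lease`) updates the PR branch with the rewritten history.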
Update utils.py Update utils.py Update utils.py Update utils.py
Update loss.py
Dear @havakv,
Hmm, still looks like the tests fail
Still failing unfortunately
This is really odd, because I'm training my model using the exact same code. I truly appreciate your patience.
So, all the tests are just run on CPU (no GPU for now), and the script is just running the pytest command as you're doing locally. If you click on "Details" on the tests, you should see output in the same format as you get locally:
If you get all the tests to pass locally (no F's) and they are still failing here, I'll try to run the code and see if I can reproduce some of the failures. It might be down to the PyTorch version. Currently the tests run on PyTorch 1.5 (which is a bit old). You can see the automatic test setup here https://github.com/havakv/pycox/blob/master/.github/workflows/pythonpackage.yml, and you can try to edit the version of PyTorch if you think that is the issue.
I have prepared this pull request to make the package more compatible with 16-bit training and deterministic training.
In models.utils.log_softplus I added a simple line that saves the input data type and converts the result back to that type after indexing. Previously, the code assumed float32 was the only possible type, and was hence not compatible with half-precision training.
Moreover, most loss functions use .gather() to get the corresponding duration index from a tensor, and .gather() is not a deterministic algorithm, which makes reproducibility of the results a nightmare, if not impossible. I created a function called replace_gather that does the same thing using flattening and torch.index_select(), which are deterministic.
I checked the code against the old version and they produce exactly the same loss values. I would appreciate your input on this PR.
Regards,
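The flatten-and-index-select idea described above can be sketched as follows. This is a minimal illustration, assuming a 2-D input and a column-index tensor as the losses use them; the helper name replace_gather comes from the PR, but the exact signature here is an assumption:

```python
import torch

def replace_gather(x, index):
    # Hypothetical sketch of a drop-in alternative to x.gather(1, index) for a
    # 2-D tensor x of shape (batch, n) and integer column indices of shape
    # (batch, 1). Flattening x and using torch.index_select avoids .gather(),
    # whose backward pass is non-deterministic on CUDA.
    rows = torch.arange(x.size(0), device=x.device)
    flat_index = rows * x.size(1) + index.view(-1)
    return torch.index_select(x.reshape(-1), 0, flat_index).view(-1, 1)

x = torch.randn(4, 5)
idx = torch.randint(0, 5, (4, 1))
# The result should match the original .gather() call exactly.
assert torch.equal(replace_gather(x, idx), x.gather(1, idx))
```

Since the replacement produces bit-identical values, swapping it in should not change any loss computation, only remove the non-deterministic kernel from the backward pass.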