Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

16bit + Deterministic Training #123

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Conversation

BardiaKh
Copy link

I have prepared this pull request to make the package more compatible with 16bit training and deterministic training.

In the models.utils.log_softplus I added a simple line to save the input data type and convert it back to that specific type after indexing. Previously, the code assumed float32 type as the only possible option; hence not compatible with half-precision training.

Moreover, most loss functions use .gather() to get the corresponding duration index from a tensor, and .gather() is not a deterministic algorithm which makes reproducibility of the results a nightmare, if not impossible. I created a function called replace_gather that does the same thing using flattening and torch.index_select() which are deterministic.

I checked the code compared with the old version and they produce exact same loss values. I would appreciate your input on this PR.

Regards,

@BardiaKh BardiaKh closed this Jan 18, 2022
@BardiaKh BardiaKh reopened this Jan 18, 2022
@BardiaKh BardiaKh changed the title 16 Bit Training + Deterministic Approach 16bit + Deterministic Training Jan 18, 2022
@havakv
Copy link
Owner

havakv commented Jan 20, 2022

Hi! It's great that you're able to contribute! It looks like some of the test here failed. You should try to use pytest on your local machine to make sure all of them pass. You should be able to install all the dev dependencies with pip install -r requirements-dev.txt, and then you can run the tests with the command pytest in the root of the pycox repository.

Second, it looks like some of your commits are just making alterations to the same softplus function (add something, then removes it again, etc). You should probably clean up commits such that they only include the relevant changes. This makes it a lot easier for the reviewer. If you're not familiar with changing the commit history, you should look up git rebase and a tutorial on how to use it.

Update utils.py

Update utils.py

Update utils.py

Update utils.py
Update loss.py
@BardiaKh
Copy link
Author

Dear @havakv,
Thanks for your advice. I cleaned up the commit history and also ran local tests using pytest. Please let me know if there is anything else I should do.

@havakv
Copy link
Owner

havakv commented Jan 21, 2022

Hmm, still looks like the tests fail

def replace_gather(tensor: Tensor, dim: int, idx: Tensor) -> Tensor:
        tensor_shape = tensor.shape
        tensor_ = tensor.flatten()
>       idx_ = idx.flatten() + torch.arange(start=0, end=tensor_shape[0]*tensor_shape[1], step=tensor_shape[1], device = idx.get_device())
E       RuntimeError: Device index must not be negative```

@havakv
Copy link
Owner

havakv commented Jan 23, 2022

Still failing unfortunately

@BardiaKh
Copy link
Author

This is really odd. Because I'm training my model using the exact same code.
Is there a way I can test it locally so I can crack down on the issue? I don't want to bother you with all the PRs and test failures.

I truly appreciate your patience.

@havakv
Copy link
Owner

havakv commented Jan 29, 2022

So, all the test are just run on CPU (no gpu for now) and the script is just running the pytest command as you're doing locally. If you click on details on the tests, you should see output in the same format as you get locally:

============================= test session starts ==============================
platform linux -- Python 3.6.15, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/runner/work/pycox/pycox
collected 198 items

tests/test_utils.py ..........                                           [  5%]
tests/evaluation/test_admin.py ....                                      [  7%]
tests/models/test_bce_surv.py ....                                       [  9%]
tests/models/test_cox.py ..                                              [ 10%]
tests/models/test_cox_cc.py ..                                           [ 11%]
tests/models/test_cox_time.py ..                                         [ 12%]
tests/models/test_deephit.py FFFF                                        [ 14%]
tests/models/test_interpolation.py ..............................        [ 29%]
tests/models/test_logistic_hazard.py FFFF                                [ 31%]
tests/models/test_loss.py FFFFFFFFF...........................FFFFFFFFFF [ 54%]
FFFFFFFFFFFFFF............F                                              [ 68%]
tests/models/test_models_utils.py ...................................    [ 85%]
tests/models/test_mtlr.py FFFF                                           [ 87%]
tests/models/test_pc_hazard.py ............FFFFFFFF                      [ 97%]
tests/models/test_pmf.py FFFF                                            [100%]

=================================== FAILURES ===================================
______________________ test_deep_hit_single_runs[2-True] _______________________

If you get all the tests to pass locally (no F's), and they are still failing here, I'll try to run the code and see if I can reproduce some of the failures. It might be down to the pytorch version. Currently the tests run the 1.5 (which is a bit old). You can see the automatic test setup here https://github.com/havakv/pycox/blob/master/.github/workflows/pythonpackage.yml and you can try to edit the version of pytorch if you think that is the issue

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants