
Conversation

@hasansalimkanmaz (Contributor) commented Apr 23, 2022

What does this PR do?

This PR ensures reproducibility for distributed training by setting a seed for each dataloader worker and setting the CUDA-related environment variables.

This PR is motivated by this issue.
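
For context, a minimal sketch of the usual PyTorch recipe for per-worker seeding; the helper name and the exact wiring in this PR's diff may differ:

```python
import random

import numpy as np
import torch


def seed_worker(_):
    # Inside a DataLoader worker, torch.initial_seed() already includes the
    # worker id, so deriving numpy/random seeds from it keeps workers
    # reproducible while still distinct from each other.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Example wiring (dataset is a placeholder):
# loader = torch.utils.data.DataLoader(
#     dataset, num_workers=4, worker_init_fn=seed_worker,
#     generator=torch.Generator().manual_seed(42))
```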

Who can review?

@saattrupdan @sgugger I am looking forward to your feedback.

@HuggingFaceDocBuilderDev commented Apr 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@saattrupdan (Contributor) left a comment

Thanks for investing your time to implement this PR! 😊

I have mostly small changes related to documentation and naming, but otherwise looks good 👍

EDIT: To enable support for TensorFlow models, you could use enable_op_determinism in the TensorFlow case.
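
For the TensorFlow side, a hedged sketch of what that could look like (enable_tf_determinism is a hypothetical name; tf.config.experimental.enable_op_determinism requires a sufficiently recent TensorFlow release):

```python
import tensorflow as tf


def enable_tf_determinism(seed: int):
    # Seed TF's global generator, then ask TF to use only deterministic
    # kernels; ops without a deterministic implementation will raise an error.
    tf.random.set_seed(seed)
    tf.config.experimental.enable_op_determinism()
```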

torch.backends.cudnn.benchmark = False


def set_seed(seed: int, set_seed_for_cuda: bool = True):

Related to the function name above, I'd argue that the argument here should be changed to something like enable_determinism. Further, I'd make the default False, as enabling it can cause weird errors if one uses algorithms that don't have a deterministic variant yet.

@sgugger (Collaborator) left a comment

Thanks for your work on this!

tf.config.experimental.enable_op_determinism()


def set_seed(seed: int, enable_determinism: bool = True):
@sgugger (Collaborator) commented Apr 25, 2022

Suggested change:
- def set_seed(seed: int, enable_determinism: bool = True):
+ def set_seed(seed: int, full_determinism: bool = False):

I like full_determinism a bit better. Since this is a new addition, the default should be set to False. Although it does fix what one might consider a bug, so I'm not sure on this one. @LysandreJik do you have an opinion?

@LysandreJik (Member) left a comment

Thanks for working on this, that's an important feature! So as to not introduce a breaking change, and for clarity of the API, I'd personally vouch for not adding the enable_determinism flag to the set_seed method.

From the title of the method I understand it should set the seed, and that's it. I don't think it should do anything else. However, the enable_determinism_for_distributed_training method likely needs the seed to be set in order to benefit from full determinism, so I'd even push to have the set_seed method called inside the enable_determinism_for_distributed_training, adding a seed argument to that last method.

What do you think?

@hasansalimkanmaz (Contributor, Author) commented Apr 27, 2022

(quoting @LysandreJik's comment above)

I like this idea, and I can implement it once we reach a conclusion on it; however, it is not clear to me how. Could you point me to the parts of the code I need to change, and what to pay attention to so nothing breaks, if we decide to go with this idea?

@sgugger (Collaborator) left a comment

Here are some pointers on what @LysandreJik suggests.

set_seed(worker_seed)


def enable_determinism_for_distributed_training():

The idea would be for this function to take seed here.

Suggested change:
- def enable_determinism_for_distributed_training():
+ def enable_full_determinism(seed: int):

and then call set_seed inside (instead of set_seed calling this function).

(Also changing the name to be a bit shorter.)
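
Putting the two suggestions together, a minimal sketch of the structure being discussed (the exact environment-variable values and framework guards are assumptions based on this thread, not the merged code):

```python
import os

import torch

from transformers import set_seed  # set_seed is part of the public transformers API


def enable_full_determinism(seed: int):
    # Seed python, numpy and torch first ...
    set_seed(seed)

    # ... then flip the CUDA/cuDNN switches needed for deterministic runs.
    # CUDA_LAUNCH_BLOCKING matters mainly on older CUDA versions;
    # CUBLAS_WORKSPACE_CONFIG makes cuBLAS GEMMs deterministic on CUDA >= 10.2.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    # Error out on ops that have no deterministic implementation instead of
    # silently running a non-deterministic kernel.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```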

Comment on lines 95 to 96
if enable_determinism:
    enable_determinism_for_distributed_training()

And so this part would disappear here; it would be the other way around.

@hasansalimkanmaz (Contributor, Author)

@sgugger Thanks for the pointers, and sorry for not being clear. I would like to know in which places enable_full_determinism should be called. Currently, set_seed is called in several places in the codebase, and I don't think these calls will be replaced with enable_full_determinism.

With the latest commits, I have already addressed your pointers. Now I am waiting for your feedback on where to call enable_full_determinism; it is not called anywhere in the codebase right now.

@sgugger (Collaborator) commented May 10, 2022

There can be an added flag in TrainingArguments, and when it is set we can call this function instead of set_seed in the Trainer. Otherwise it will be up to users to call this one instead of set_seed in their own scripts (you should make it accessible in the main init, by the way!).
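
A rough sketch of that wiring, assuming enable_full_determinism ends up exposed in the main init as suggested (the helper _setup_seed and the exact call site in the Trainer are hypothetical):

```python
from transformers import enable_full_determinism, set_seed


def _setup_seed(args):
    # One TrainingArguments flag decides between plain seeding and full
    # (slower) determinism.
    if getattr(args, "full_determinism", False):
        enable_full_determinism(args.seed)
    else:
        set_seed(args.seed)
```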

@hasansalimkanmaz (Contributor, Author)

@sgugger I think I have addressed all your comments. Is there anything that needs to be done for this PR?

@sgugger (Collaborator) left a comment

Thanks, just one last comment on the doc!

@hasansalimkanmaz (Contributor, Author)

Is it normal that 3 tests suddenly fail after a commit that only touches a docstring? I can't figure out why they are failing.

@sgugger (Collaborator) commented May 11, 2022

Those are just flaky and unrelated to your PR. Thanks again for all your work on this!

@sgugger sgugger merged commit c33f604 into huggingface:main May 11, 2022
@hasansalimkanmaz hasansalimkanmaz deleted the enable-reproducibility-during-distributed-training branch May 11, 2022 14:34
ArthurZucker pushed a commit to ArthurZucker/transformers that referenced this pull request May 12, 2022
…6907)

* add seed worker and set_deterministic_seed_for_cuda function to enforce reproducability

* change function name to enable determinism, add docstrings, reproducability support for tf

* change function name to enable_determinism_for_distributed_training

* revert changes in set_seed and call set_seed within enable_full_determinism

* add one position argument for seed_worker function

* add full_determinism flag in training args and call enable_full_determinism when it is true

* add enable_full_determinism to documentation

* apply make fixup after the last commit

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
@alexcoca commented May 18, 2022

@sgugger @hasansalimkanmaz I had a question about this PR: why is it necessary to set CUDA_LAUNCH_BLOCKING? This disables asynchronous execution of CUDA programs, but the CUDA/PyTorch docs don't mention it being necessary for deterministic training. I do use it to get the "true" stack trace when there are device-side asserts, but I was wondering what role it plays in deterministic training. Many thanks!

@saattrupdan (Contributor)

@alexcoca It's required to make some CUDA algorithms deterministic if the CUDA version is older than 10.2. I suppose it could be replaced by a CUDA version check somehow, and only use it if it's an old version?
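
A hedged sketch of such a version check (the helper name is hypothetical, and the 10.2 cutoff simply follows the explanation above):

```python
import torch


def needs_cuda_launch_blocking() -> bool:
    # Per the comment above, CUDA_LAUNCH_BLOCKING is only needed for
    # determinism on CUDA versions older than 10.2.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = (int(part) for part in torch.version.cuda.split(".")[:2])
    return (major, minor) < (10, 2)
```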

@alexcoca

@saattrupdan I would go for this approach, because disabling asynchronous execution of CUDA kernels will definitely slow things down beyond belief. I implemented this PR myself without the CUDA_LAUNCH_BLOCKING setting and will report back if I manage to preserve determinism.

@alexcoca commented May 20, 2022

I experimented with training a dialogue state tracking model on the SGD corpus, starting from Google's T5 v1.1 (220M parameters). I let the model train for roughly two epochs and evaluated task-oriented performance every 2k steps (max train steps was 12k).

I ran 4 experiments: 2 in which I set the seed, and an additional 2 where I did roughly the same as enable_full_determinism except for setting CUDA_LAUNCH_BLOCKING. I also set CUBLAS_WORKSPACE_CONFIG=':4096:8'. Each experiment was trained on 2 A100-80GB GPUs with cuda/11.4, openmpi/4.1.1/gcc-9.4.0-epagguv, PyTorch 1.10, and transformers 4.19.2. You can see below that I was able to reproduce the metrics in all runs, with no major performance hit. I guess that convolution benchmarking and non-deterministic ops are less relevant for T5. With 4.18.0, results varied wildly across runs with the same seed, a sign that data ordering was the culprit.

[Two screenshots of evaluation metric plots for the four runs.]
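
For reference, a guess at roughly what that setup looks like; the exact calls used in these runs are not shown in the thread:

```python
import os

import torch
from transformers import set_seed

# Deterministic cuBLAS workspace, but CUDA_LAUNCH_BLOCKING stays unset so
# kernel launches remain asynchronous.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

set_seed(42)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```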

I guess the moral of the story here is that one could:

1. drop CUDA_LAUNCH_BLOCKING from enable_full_determinism (or only set it when the CUDA version is older than 10.2), and
2. also set CUBLAS_WORKSPACE_CONFIG as part of enable_full_determinism.

@sgugger ?

@sgugger (Collaborator) commented May 20, 2022

Agreed for the first one. For the second one, we could avoid overriding an existing CUBLAS_WORKSPACE_CONFIG if it's already set in the env? In all cases, it should be clearly stated in the doc of the flag that triggers full reproducibility that it comes at a performance price.
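
A one-line sketch of that idea (the default value shown is only a placeholder):

```python
import os

# Only set a default if the user hasn't already configured cuBLAS themselves.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
```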

@alexcoca

Yes, I agree with the above! I'm at ACL next week but I'll try and open a small PR to address this the week after!

@hasansalimkanmaz (Contributor, Author)

Thanks, @alexcoca, for noticing this and for your time.

elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…6907)

(same commit message as in the ArthurZucker entry above)