[train][jax_trainer] add jax.distributed.shutdown() for JaxBackend
#57802
Signed-off-by: Lehui Liu <[email protected]>
Any way we could add a unit test for this?
I tried a few approaches, including using caplog and checking jax.process_count(), but JAX does not seem to support a CPU distributed environment well.
I am adding a release test for it instead. Would it be OK to check it there? #57815
…ay-project#57802)
## Description
1. This PR added the `jax.distributed.shutdown()` for JaxBackend in order to free up any leaked resources on TPU RayTrainWorkers.
2. if `jax.distributed` is not on, it is a noop: https://docs.jax.dev/en/latest/_autosummary/jax.distributed.shutdown.html
3. Tested on Anyscale workspace.
<img width="1264" height="62" alt="image" src="https://github.com/user-attachments/assets/f28102ff-f6d1-4da0-b41a-6cc785603e72" />
Signed-off-by: xgui <[email protected]>
Description
1. This PR adds `jax.distributed.shutdown()` to JaxBackend to free up any leaked resources on TPU RayTrainWorkers.
2. If `jax.distributed` is not running, the call is a no-op: https://docs.jax.dev/en/latest/_autosummary/jax.distributed.shutdown.html
3. Tested on an Anyscale workspace.
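The cleanup described in this PR can be sketched as a small helper that a worker runs on shutdown. This is a minimal sketch under stated assumptions, not the code merged here: the function name `shutdown_jax_distributed` and its boolean return value are hypothetical, and only `jax.distributed.shutdown()` itself comes from the PR description.

```python
import logging

logger = logging.getLogger(__name__)


def shutdown_jax_distributed() -> bool:
    """Release JAX distributed resources on this worker, if any.

    Hypothetical helper. Returns True when the shutdown call completed
    and False when JAX is unavailable or the call raised.
    """
    try:
        import jax
    except ImportError:
        logger.debug("jax is not installed; nothing to shut down.")
        return False

    try:
        # Per the linked JAX docs, this does nothing when the distributed
        # system is not running, so it is safe to call unconditionally.
        jax.distributed.shutdown()
    except Exception:
        # Assumption: a failed cleanup should be logged, not crash the worker.
        logger.warning("jax.distributed.shutdown() failed", exc_info=True)
        return False
    return True
```

Swallowing the exception is a design choice made for this sketch, on the assumption that a failed cleanup should not mask the result of the training run.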