@liulehui liulehui commented Oct 16, 2025

Description

  1. This PR adds a jax.distributed.shutdown() call to JaxBackend to free any leaked resources on TPU RayTrainWorkers.
  2. If jax.distributed is not initialized, the call is a no-op: https://docs.jax.dev/en/latest/_autosummary/jax.distributed.shutdown.html
  3. Tested on an Anyscale workspace.
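
A minimal sketch of the cleanup described in the list above, not the actual Ray Train implementation. The helper name `shutdown_jax_distributed` and its injectable `jax_module` parameter are hypothetical and exist only for illustration; the only real JAX call used is `jax.distributed.shutdown()`.

```python
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


def shutdown_jax_distributed(jax_module: Optional[Any] = None) -> bool:
    """Best-effort teardown of the JAX distributed runtime on a worker.

    Hypothetical helper: `jax_module` lets a caller (or a test) pass in a
    stand-in for the real `jax` package. Returns True if the shutdown call
    completed without raising, False otherwise.
    """
    if jax_module is None:
        try:
            import jax as jax_module  # imported lazily so non-JAX workers are unaffected
        except ImportError:
            logger.debug("jax is not installed; nothing to shut down.")
            return False

    try:
        # Per the JAX docs linked above, this is a no-op when
        # jax.distributed was never initialized.
        jax_module.distributed.shutdown()
    except Exception:
        # Assumption: cleanup failures should be logged, not crash worker teardown.
        logger.warning("jax.distributed.shutdown() failed", exc_info=True)
        return False
    return True
```

A backend could call a helper like this from its worker shutdown path, so that each training worker releases its distributed state before the process is reused.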

Related issues

Types of change

  • Bug fix 🐛
  • New feature ✨
  • Enhancement 🚀
  • Code refactoring 🔧
  • Documentation update 📖
  • Chore 🧹
  • Style 🎨

Checklist

Does this PR introduce breaking changes?

  • Yes ⚠️
  • No

Testing:

  • Added/updated tests for my changes
  • Tested the changes manually
  • This PR is not tested ❌ (please explain why)

Code Quality:

  • Signed off every commit (git commit -s)
  • Ran pre-commit hooks (setup guide)

Documentation:

  • Updated documentation (if applicable) (contribution guide)
  • Added new APIs to doc/source/ (if applicable)

Additional context

@liulehui liulehui added the go add ONLY when ready to merge, run all tests label Oct 16, 2025
Contributor

Any way we could add a unit test for this?

Contributor Author

I tried a few approaches, including using caplog and checking jax.process_count(), but JAX does not seem to support a CPU distributed environment well. I am adding a release test for it instead; would it be OK to verify it there? See #57815.

Signed-off-by: Lehui Liu <[email protected]>
Signed-off-by: Lehui Liu <[email protected]>
Signed-off-by: Lehui Liu <[email protected]>
@liulehui liulehui marked this pull request as ready for review October 17, 2025 22:56
@liulehui liulehui requested a review from a team as a code owner October 17, 2025 22:56
@ray-gardener ray-gardener bot added the train Ray Train Related Issue label Oct 18, 2025
@matthewdeng matthewdeng merged commit bf206b2 into ray-project:master Oct 20, 2025
6 checks passed
xinyuangui2 pushed a commit to xinyuangui2/ray that referenced this pull request Oct 22, 2025
…ay-project#57802)

## Description

1. This PR adds a `jax.distributed.shutdown()` call to JaxBackend to
free any leaked resources on TPU RayTrainWorkers.
2. If `jax.distributed` is not initialized, the call is a no-op:
https://docs.jax.dev/en/latest/_autosummary/jax.distributed.shutdown.html
3. Tested on Anyscale workspace.
<img width="1264" height="62" alt="image"
src="https://github.com/user-attachments/assets/f28102ff-f6d1-4da0-b41a-6cc785603e72"
/>

Signed-off-by: xgui <[email protected]>
elliot-barn pushed a commit that referenced this pull request Oct 23, 2025
…57802)

Signed-off-by: elliot-barn <[email protected]>
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025
…ay-project#57802)

Aydin-ab pushed a commit to Aydin-ab/ray-aydin that referenced this pull request Nov 19, 2025
…ay-project#57802)

Signed-off-by: Aydin Abiar <[email protected]>

Labels

go (add ONLY when ready to merge, run all tests)
train (Ray Train Related Issue)

2 participants