Multistart Parallelization with Jit Compatible Code #493

Open
nikithiel opened this issue Apr 2, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@nikithiel
Hey there,

I am trying to run a gradient-based algorithm with multistart of my jit compatible code in parallel. Can I use estimagic's parallelisation using 'nprocs' via joblib or pathos or do I need to create a sample for the exploration phase manually and distribute it using jax parallelisation?

When running multistart=True with n_procs=2, I'm encountering the following warning:

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.

If helpful, I can post some code snippets from my implementation.

@nikithiel nikithiel added the bug Something isn't working label Apr 2, 2024
@timmens
Member

timmens commented Apr 7, 2024

Hey @nikithiel,

Thanks for opening the issue!

Could you show me a minimal reproducible example of the behavior mentioned? And could you also post the versions of estimagic, joblib, and jax/jaxlib you use?

That would be very helpful in solving your issue!

@nikithiel
Author

Hey @timmens,

I'm using the following versions:

estimagic=0.4.6
joblib=1.3.2
jax/jaxlib=0.4.26

This warning occurs when I run my code on a Linux HPC, so I think it's related to that. I found two very interesting posts:

https://discuss.python.org/t/switching-default-multiprocessing-context-to-spawn-on-posix-as-well/21868/22
jax-ml/jax#18852 --> jax-ml/jax#18989

It seems that in multiprocessing, spawn is the default start method on Windows and macOS, while on Linux it is fork. Forking is unsafe in a multithreaded process, and JAX is always multithreaded.

I'm not sure how to force joblib to use spawn tbh. Maybe by changing the backend argument in the joblib.Parallel call? Maybe this helps:

jax-ml/jax#18852
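
For reference, the start methods mentioned above can be inspected with the standard library alone. This is only a minimal sketch of how `multiprocessing` exposes fork and spawn; it does not show how joblib or estimagic pick their backend, and whether they can be pointed at a spawn context is not established in this thread.

```python
import multiprocessing as mp

# Start methods supported on this platform (for example fork, spawn, forkserver).
available = mp.get_all_start_methods()

# With allow_none=True this returns None if no method has been fixed yet,
# instead of fixing the platform default as a side effect.
current = mp.get_start_method(allow_none=True)

# A context object scopes the choice of start method to the processes and
# pools created from it, without changing the global default.
spawn_ctx = mp.get_context("spawn")

print(available, current, spawn_ctx.get_start_method())
```

Using a context rather than `mp.set_start_method` keeps the change local, which matters in a large project where other libraries may rely on the default.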

I could also try to create an MWE. However, this is not so straightforward, as the problem occurs in a large code project.

Hope this helps,
Niklas

@nikithiel nikithiel closed this as not planned Apr 8, 2024
@nikithiel
Author

I accidentally closed this issue. Sorry!

@nikithiel nikithiel reopened this Apr 8, 2024
@hmgaudecker
Member

It's probably just a check in JAX for whether fork has been called.

The same thing happened to me recently in a project that uses pytask-parallel and JAX.

@timmens
Member

timmens commented Apr 9, 2024

I've found a minimal reproducible example using JAX and joblib and a way to fix it (in the MRE).

As you correctly anticipated, @nikithiel, choosing a different parallelization backend fixes the problem in the MRE. To check whether this also fixes your problem, you could use a local estimagic installation and add the backend="threading" argument to the joblib batch evaluator in the batch_evaluators.py module.

Additionally, you can always run the multistart in serial using multistart_options = {"n_cores": 1}, which could already be fast enough since your objective function is multi-threaded.

Note

The following is tested on my Linux ThinkPad and might not work on your HPC machine.

@janosg, I propose we add an option to the batch_evaluator for custom kwargs and allow these to be passed through the multistart_options. What are your thoughts?

Minimal Reproducible Example

import jax.numpy as jnp
from joblib import Parallel, delayed

x_list = [jnp.ones(2) for _ in range(2)]

# Backend: loky (results in a warning)
Parallel(n_jobs=2, backend="loky")(delayed(jnp.mean)(x) for x in x_list)

# Backend: threading (does *not* result in a warning)
Parallel(n_jobs=2, backend="threading")(delayed(jnp.mean)(x) for x in x_list)

Backend: loky

Results in the warning

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is
multithreaded, so this will likely lead to a deadlock.

Backend: threading

Results in no warning.

Versions

  • estimagic = 0.4.7
  • joblib = 1.3.2
  • jax/jaxlib = 0.4.26

@nikithiel
Author

nikithiel commented Apr 12, 2024

Hey @timmens,

Thanks for the MRE and the suggested solution. I changed the argument as suggested and I no longer get the warning on the HPC machine either. I also compared the performance of my bigger code project for a serial run, a run with backend=loky and a run with backend=threading. Interestingly, loky is 3 times faster than serial and threading is 5 times faster.

@janosg
Member

janosg commented Apr 12, 2024

This is a very important use case for us and we should offer a batch evaluator that supports JAX functions. It's not just for multistart but also for bootstrap or parallelizing optimizers. Instead of making the batch evaluator configurable with more arguments, I would probably just add a new batch evaluator.
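
To make the idea of a separate, fork-free batch evaluator concrete, here is a rough sketch built on threads from the standard library. The name `thread_batch_evaluator` and its signature are hypothetical and only loosely modeled on a generic "function plus list of arguments" interface; this is not estimagic's API. Whether threads give a speedup depends on the objective function releasing the GIL while it runs.

```python
from concurrent.futures import ThreadPoolExecutor


def thread_batch_evaluator(func, arguments, *, n_cores=1, unpack_symbol=None):
    """Evaluate func on each element of arguments using threads.

    Hypothetical helper for illustration. Threads never call os.fork(),
    so the fork warning discussed in this thread cannot be triggered here.
    """
    if unpack_symbol not in (None, "*", "**"):
        raise ValueError("unpack_symbol must be None, '*' or '**'.")

    def call(arg):
        # Optionally unpack each argument as positional or keyword arguments.
        if unpack_symbol == "*":
            return func(*arg)
        if unpack_symbol == "**":
            return func(**arg)
        return func(arg)

    if n_cores == 1:
        # Serial fallback, no thread pool needed.
        return [call(arg) for arg in arguments]

    with ThreadPoolExecutor(max_workers=n_cores) as pool:
        # Executor.map returns results in the order of the inputs.
        return list(pool.map(call, arguments))


results = thread_batch_evaluator(lambda x: x**2, [1, 2, 3], n_cores=2)
print(results)
```

Keeping the serial path separate means the same function can be used for debugging with `n_cores=1` without any threading involved.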

In the meantime I see two workarounds:

  1. Downgrade JAX. I think a year ago or so we did not have these problems.
  2. Disable parallelization in JAX with something like this.

Disabling JAX's default parallelism is probably a good idea anyway when you do multistart. Running multiple optimizations in parallel is a very simple and efficient form of parallelization. So as long as you have enough optimizations to keep your computer busy, you probably don't want to parallelize the objective function.
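
One way to restrict JAX to a single CPU thread is sketched below. This is an assumption, not necessarily the approach the comment above linked to: the two XLA flags are the ones commonly suggested in JAX issue discussions, they are not a documented, stable API, and their effect may differ between JAX versions.

```python
import os

# Assumption: flags commonly suggested in JAX discussions for limiting
# CPU execution to one thread. They are not a documented, stable API.
single_thread_flags = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)

# Append to any flags that are already set rather than overwriting them.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{existing} {single_thread_flags}".strip()

# Import jax only after this point so the flags are read at start-up.
# import jax
```

The environment variable has to be set before JAX is first imported in the process, so in a larger project it belongs at the very top of the entry script or in the job submission script.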
