diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 8a528c0727..98f0351475 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -386,9 +386,13 @@ def __init__( if mp_ctx is None or isinstance(mp_ctx, str): # Closes issue https://github.com/pymc-devs/pymc/issues/3849 # Related issue https://github.com/pymc-devs/pymc/issues/5339 - if platform.system() == "Darwin": + if mp_ctx is None and platform.system() == "Darwin": if platform.processor() == "arm": mp_ctx = "fork" + logger.debug( + "mp_ctx is set to 'fork' for MacOS with ARM architecture. " + + "This might cause unexpected behavior with JAX, which is inherently multithreaded." + ) else: mp_ctx = "forkserver"