diff --git a/torchtitan/experiments/rl/grpo.py b/torchtitan/experiments/rl/grpo.py
index 68e7bf2e44..33404cbb34 100644
--- a/torchtitan/experiments/rl/grpo.py
+++ b/torchtitan/experiments/rl/grpo.py
@@ -124,6 +124,8 @@ def allocate(self, num_gpus: int) -> Callable[[], None]:
         def _bootstrap():
             os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
+            # TODO: Remove once Monarch/PyTorch fixes concurrent import during unpickling.
+            import torch  # noqa: F401
         return _bootstrap