diff --git a/train_ms.py b/train_ms.py index 8a6fcab24..2aa93852b 100644 --- a/train_ms.py +++ b/train_ms.py @@ -40,6 +40,7 @@ True ) # Not available if torch version is lower than 2.0 torch.backends.cuda.enable_math_sdp(True) +torch.multiprocessing.set_start_method('spawn') global_step = 0