@@ -174,6 +174,7 @@ def build_dataloader(
         num_workers=0 if cfg.generator.inference_engine.enable_http_endpoint else 8,
         drop_last=True if is_train else False,
         generator=seeded_generator,
+        multiprocessing_context="spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None,
     )
     if is_train:
         if not is_fully_async:
9 changes: 6 additions & 3 deletions skyrl/train/entrypoints/main_base.py
@@ -3,6 +3,7 @@
 """

 import asyncio
+import multiprocessing as mp
 import os
 import sys
 from pathlib import Path
@@ -34,10 +35,12 @@
     initialize_ray,
 )
 from skyrl.utils.tok import get_tokenizer
-from skyrl.utils.worker_setup import worker_setup_fn

-# Run setup function to ensure driver process has consistent setup as Ray workers
-worker_setup_fn()
+# NOTE (sumanthrh): We use ray heavily and thus disable the `fork` start method.
+# Forking within ray leads to undefined behaviour and often causes hard-to-debug
+# memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html
+# A common culprit is PyTorch dataloaders, which use `fork` by default.
+mp.set_start_method("spawn", force=True)

 config_dir = str(Path(__file__).parent.parent / "config")
 __all__ = ["BasePPOExp", "config_dir"]
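For context on the new module-level call: `spawn` starts child processes from a fresh interpreter rather than forking the parent, so none of the parent's runtime state (Ray threads, locks, CUDA contexts) is inherited. A minimal, self-contained sketch of the pattern; the `child` function here is illustrative, not part of the PR:

```python
import multiprocessing as mp


def child(x: int) -> None:
    # Runs in a freshly spawned interpreter: nothing from the parent's
    # runtime state (Ray threads, locks, CUDA contexts) is inherited.
    print(f"child received {x}")


if __name__ == "__main__":
    # Must run before any worker processes are created; force=True
    # overrides a start method that a library may have set earlier.
    mp.set_start_method("spawn", force=True)
    p = mp.Process(target=child, args=(42,))
    p.start()
    p.join()
```

Setting this at import time in the entrypoint, as the diff does, fixes the start method before any dataloader or pool is constructed.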
4 changes: 4 additions & 0 deletions skyrl/train/utils/trainer_utils.py
@@ -690,6 +690,10 @@ def build_dataloader(
         num_workers=0 if cfg.generator.inference_engine.enable_http_endpoint else 8,
         drop_last=True if is_train else False,
         generator=seeded_generator,
+        # NOTE (sumanthrh): We use ray and thus use the `spawn` start method.
+        # Forking within ray leads to undefined behaviour and often causes hard-to-debug
+        # memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html
+        multiprocessing_context="spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None,
     )
     if is_train:
         if not is_fully_async:
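The `multiprocessing_context` argument scopes the same choice to a single loader rather than the whole process. A minimal sketch with a toy dataset standing in for the real one (the dataset and sizes are illustrative):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 16

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx], dtype=torch.float32)


if __name__ == "__main__":
    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        num_workers=2,                    # worker processes load batches
        multiprocessing_context="spawn",  # spawn workers instead of forking
        generator=torch.Generator().manual_seed(0),
    )
    for batch in loader:
        print(batch.shape)  # torch.Size([4, 1])
```

PyTorch rejects a non-`None` `multiprocessing_context` when `num_workers == 0`, which is why the diff falls back to `None` on the HTTP-endpoint path where workers are disabled.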
8 changes: 1 addition & 7 deletions skyrl/train/utils/utils.py
@@ -747,13 +747,7 @@ def initialize_ray(cfg: SkyRLTrainConfig):

     # log_to_driver=True allows training progress from skyrl_entrypoint to reach stdout.
     # Infrastructure logs (vLLM, workers) are redirected to the log file via os.dup2 in their init.
-    ray.init(
-        runtime_env={
-            "env_vars": env_vars,
-            "worker_process_setup_hook": "skyrl.utils.worker_setup.worker_setup_fn",
-        },
-        log_to_driver=True,
-    )
+    ray.init(runtime_env={"env_vars": env_vars}, log_to_driver=True)

     if not verbose_logging:
         logger.info(f"Infrastructure logs will be written to: {log_file}")
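With the setup hook removed, the runtime environment only forwards environment variables to workers. A minimal sketch of the remaining pattern (the variable name and value here are illustrative):

```python
import os

import ray

# Environment variables are propagated to every Ray worker process.
env_vars = {"TOKENIZERS_PARALLELISM": "false"}

ray.init(runtime_env={"env_vars": env_vars}, log_to_driver=True)


@ray.remote
def check_env() -> str:
    # Executes in a worker; sees the env vars from the runtime_env.
    return os.environ.get("TOKENIZERS_PARALLELISM", "unset")


print(ray.get(check_env.remote()))  # -> false
ray.shutdown()
```

This pairs with the deletions below: once the start method is fixed at import time in the entrypoint, the dedicated worker-setup module the hook pointed at is no longer needed.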
19 changes: 0 additions & 19 deletions skyrl/utils/worker_setup.py

This file was deleted.

46 changes: 0 additions & 46 deletions tests/backends/skyrl_train/utils/test_worker_setup.py

This file was deleted.
