Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions nemo_rl/utils/prefetch_venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from pathlib import Path
Expand All @@ -21,16 +22,35 @@
from nemo_rl.utils.venvs import create_local_venv


def prefetch_venvs():
"""Prefetch all virtual environments that will be used by workers."""
def prefetch_venvs(filters=None):
"""Prefetch all virtual environments that will be used by workers.

Args:
filters: List of strings to match against actor FQNs. If provided, only
actors whose FQN contains at least one of the filter strings will
be prefetched. If None, all venvs are prefetched.
"""
print("Prefetching virtual environments...")
if filters:
print(f"Filtering for: {filters}")

# Track statistics for summary
skipped_by_filter = []
skipped_system_python = []
prefetched = []
failed = []

# Group venvs by py_executable to avoid duplicating work
venv_configs = {}
for actor_fqn, py_executable in ACTOR_ENVIRONMENT_REGISTRY.items():
# Apply filters if provided
if filters and not any(f in actor_fqn for f in filters):
skipped_by_filter.append(actor_fqn)
continue
# Skip system python as it doesn't need a venv
if py_executable == "python" or py_executable == sys.executable:
print(f"Skipping {actor_fqn} (uses system Python)")
skipped_system_python.append(actor_fqn)
continue

# Only create venvs for uv-based executables
Expand All @@ -47,12 +67,31 @@ def prefetch_venvs():
try:
python_path = create_local_venv(py_executable, actor_fqn)
print(f" Success: {python_path}")
prefetched.append(actor_fqn)
except Exception as e:
print(f" Error: {e}")
failed.append(actor_fqn)
# Continue with other venvs even if one fails
continue

print("\nVenv prefetching complete!")
# Print summary
print("\n" + "=" * 50)
print("Venv prefetching complete! Summary:")
print("=" * 50)
print(f" Prefetched: {len(prefetched)}")
for actor_fqn in prefetched:
print(f" - {actor_fqn}")
print(f" Skipped (system Python): {len(skipped_system_python)}")
for actor_fqn in skipped_system_python:
print(f" - {actor_fqn}")
if filters:
print(f" Skipped (filtered out): {len(skipped_by_filter)}")
for actor_fqn in skipped_by_filter:
print(f" - {actor_fqn}")
if failed:
print(f" Failed: {len(failed)}")
for actor_fqn in failed:
print(f" - {actor_fqn}")

# Create convenience python wrapper scripts for frozen environment support (container-only)
create_frozen_environment_symlinks(venv_configs)
Expand Down Expand Up @@ -150,4 +189,28 @@ def create_frozen_environment_symlinks(venv_configs):


if __name__ == "__main__":
    # Command-line entry point: optional positional filter strings restrict
    # which actors get their venvs prefetched.
    #
    # NOTE: prefetch_venvs() must only be invoked once, *after* the arguments
    # have been parsed; an unconditional call before parsing would prefetch
    # every venv regardless of the requested filters (and even for --help).
    parser = argparse.ArgumentParser(
        description="Prefetch virtual environments for Ray actors.",
        # Raw formatter keeps the line breaks of the examples in the epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Prefetch all venvs
python -m nemo_rl.utils.prefetch_venvs

# Prefetch only vLLM-related venvs
python -m nemo_rl.utils.prefetch_venvs vllm

# Prefetch multiple specific venvs
python -m nemo_rl.utils.prefetch_venvs vllm policy environment
""",
    )
    parser.add_argument(
        "filters",
        nargs="*",
        help="Filter strings to match against actor FQNs. Only actors whose FQN "
        "contains at least one of these strings will be prefetched. "
        "If not provided, all venvs are prefetched.",
    )
    args = parser.parse_args()

    # nargs="*" yields an empty list when no filters are given; normalise that
    # to None so prefetch_venvs treats it as "no filtering".
    prefetch_venvs(filters=args.filters or None)
272 changes: 272 additions & 0 deletions tests/unit/utils/test_prefetch_venvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import patch

import pytest

import nemo_rl.utils.prefetch_venvs as prefetch_venvs_module

# Expected-call-count scale factor. With NRL_CONTAINER set,
# create_frozen_environment_symlinks calls create_local_venv a second time for
# each actor, so every expected call count in these tests doubles.
CALL_MULTIPLIER = 1 if not os.environ.get("NRL_CONTAINER") else 2


@pytest.fixture
def mock_registry():
    """Return a fake actor registry mapping actor FQNs to py_executables.

    Contains three uv-based workers (two sharing the vllm group, one using
    mcore) followed by two environments that run on plain ``python``.
    Insertion order is significant for tests that use ordered side effects.
    """
    uv_vllm = "uv run --group vllm"
    uv_mcore = "uv run --group mcore"
    system_python = "python"

    entries = [
        (
            "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker",
            uv_vllm,
        ),
        (
            "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker",
            uv_vllm,
        ),
        (
            "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker",
            uv_mcore,
        ),
        (
            "nemo_rl.environments.math_environment.MathEnvironment",
            system_python,
        ),
        (
            "nemo_rl.environments.code_environment.CodeEnvironment",
            system_python,
        ),
    ]
    return dict(entries)


@pytest.fixture
def prefetch_venvs_func(mock_registry):
    """Yield ``prefetch_venvs`` while the module's registry is swapped for the mock.

    The patch on ``ACTOR_ENVIRONMENT_REGISTRY`` stays active for the duration
    of the test and is undone when the fixture is torn down.
    """
    registry_patch = patch.object(
        prefetch_venvs_module, "ACTOR_ENVIRONMENT_REGISTRY", mock_registry
    )
    with registry_patch:
        yield prefetch_venvs_module.prefetch_venvs


class TestPrefetchVenvs:
    """Tests for the prefetch_venvs function.

    Every test runs ``prefetch_venvs`` with ``create_local_venv`` mocked out
    and then inspects either the recorded calls or the printed summary.
    """

    # Shared literals used across the tests below.
    _CREATE_VENV_TARGET = "nemo_rl.utils.prefetch_venvs.create_local_venv"
    _FAKE_PYTHON = "/path/to/venv/bin/python"
    _VLLM_WORKER = "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker"
    _DTENSOR_WORKER = (
        "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"
    )
    _MEGATRON_WORKER = (
        "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker"
    )

    def _run(self, prefetch, filters, **mock_config):
        """Call ``prefetch(filters=...)`` with create_local_venv mocked; return the mock.

        With no ``mock_config`` the mock simply returns a fake python path;
        otherwise the keyword arguments (e.g. ``side_effect``) configure it.
        """
        if not mock_config:
            mock_config = {"return_value": self._FAKE_PYTHON}
        with patch(self._CREATE_VENV_TARGET, **mock_config) as create_venv:
            prefetch(filters=filters)
        return create_venv

    @staticmethod
    def _requested_fqns(create_venv):
        """Return the actor FQN (second positional arg) of every recorded call."""
        return [recorded[0][1] for recorded in create_venv.call_args_list]

    def _failing_then_ok(self):
        """Side effects: first call raises, the next two succeed (per multiplier)."""
        return [
            Exception("Test error"),
            self._FAKE_PYTHON,
            self._FAKE_PYTHON,
        ] * CALL_MULTIPLIER

    def test_prefetch_venvs_no_filters(self, prefetch_venvs_func):
        """Test that all uv-based venvs are prefetched when no filters are provided."""
        create_venv = self._run(prefetch_venvs_func, None)

        assert create_venv.call_count == 3 * CALL_MULTIPLIER

        # Verify the actors that were called
        requested = self._requested_fqns(create_venv)
        assert self._VLLM_WORKER in requested
        assert self._DTENSOR_WORKER in requested
        assert self._MEGATRON_WORKER in requested

    def test_prefetch_venvs_single_filter(self, prefetch_venvs_func):
        """Test filtering with a single filter string."""
        create_venv = self._run(prefetch_venvs_func, ["vllm"])

        # Should only create venvs for actors containing "vllm" (1 actor)
        assert create_venv.call_count == 1 * CALL_MULTIPLIER

        # The most recent call must have targeted the vLLM worker.
        assert create_venv.call_args[0][1] == self._VLLM_WORKER

    def test_prefetch_venvs_multiple_filters(self, prefetch_venvs_func):
        """Test filtering with multiple filter strings (OR logic)."""
        create_venv = self._run(prefetch_venvs_func, ["vllm", "megatron"])

        # Should create venvs for actors containing "vllm" OR "megatron" (2 actors)
        assert create_venv.call_count == 2 * CALL_MULTIPLIER

        requested = self._requested_fqns(create_venv)
        assert self._VLLM_WORKER in requested
        assert self._MEGATRON_WORKER in requested

    def test_prefetch_venvs_filter_no_match(self, prefetch_venvs_func):
        """Test that no venvs are created when filter matches nothing."""
        create_venv = self._run(prefetch_venvs_func, ["nonexistent"])

        # Should not create any venvs
        assert create_venv.call_count == 0

    def test_prefetch_venvs_skips_system_python(self, prefetch_venvs_func):
        """Test that system python actors are skipped even if they match filters."""
        # Filter for "environment" which matches system python actors
        create_venv = self._run(prefetch_venvs_func, ["environment"])

        # Should not create any venvs since matching actors use system python
        assert create_venv.call_count == 0

    def test_prefetch_venvs_partial_match(self, prefetch_venvs_func):
        """Test that filter matches partial strings within FQN."""
        # "policy" should match both dtensor_policy_worker and megatron_policy_worker
        create_venv = self._run(prefetch_venvs_func, ["policy"])

        assert create_venv.call_count == 2 * CALL_MULTIPLIER

        requested = self._requested_fqns(create_venv)
        assert self._DTENSOR_WORKER in requested
        assert self._MEGATRON_WORKER in requested

    def test_prefetch_venvs_empty_filter_list(self, prefetch_venvs_func):
        """Test that empty filter list is treated as no filtering (falsy)."""
        # Empty list should be falsy and prefetch all
        create_venv = self._run(prefetch_venvs_func, [])

        # Should create venvs for all uv-based actors (3 total)
        assert create_venv.call_count == 3 * CALL_MULTIPLIER

    def test_prefetch_venvs_continues_on_error(self, prefetch_venvs_func):
        """Test that prefetching continues even if one venv creation fails."""
        # Provide enough return values for both prefetch and frozen env symlinks.
        # Should not raise, should continue with other venvs.
        create_venv = self._run(
            prefetch_venvs_func, None, side_effect=self._failing_then_ok()
        )

        # All 3 uv-based actors should have been attempted
        assert create_venv.call_count == 3 * CALL_MULTIPLIER

    def test_prefetch_venvs_case_sensitive_filter(self, prefetch_venvs_func):
        """Test that filters are case-sensitive."""
        # "VLLM" (uppercase) should not match "vllm" (lowercase)
        create_venv = self._run(prefetch_venvs_func, ["VLLM"])

        assert create_venv.call_count == 0

    def test_prefetch_venvs_summary_no_filters(self, prefetch_venvs_func, capsys):
        """Test that summary is printed with correct counts and names when no filters."""
        self._run(prefetch_venvs_func, None)

        output = capsys.readouterr().out
        assert "Venv prefetching complete! Summary:" in output
        assert "Prefetched: 3" in output
        assert "Skipped (system Python): 2" in output
        # Verify prefetched env names are listed
        for worker_name in (
            "VllmGenerationWorker",
            "DTensorPolicyWorker",
            "MegatronPolicyWorker",
        ):
            assert worker_name in output
        # Verify skipped env names are listed
        for env_name in ("MathEnvironment", "CodeEnvironment"):
            assert env_name in output
        # "Skipped (filtered out)" should not appear when no filters
        assert "Skipped (filtered out)" not in output

    def test_prefetch_venvs_summary_with_filters(self, prefetch_venvs_func, capsys):
        """Test that summary includes filtered out names when filters are used."""
        self._run(prefetch_venvs_func, ["vllm"])

        output = capsys.readouterr().out
        assert "Venv prefetching complete! Summary:" in output
        assert "Prefetched: 1" in output
        assert "Skipped (system Python): 0" in output
        assert "Skipped (filtered out): 4" in output
        # Verify prefetched env name is listed
        assert "VllmGenerationWorker" in output
        # Verify filtered out env names are listed
        assert "DTensorPolicyWorker" in output
        assert "MegatronPolicyWorker" in output

    def test_prefetch_venvs_summary_with_failures(self, prefetch_venvs_func, capsys):
        """Test that summary includes failed actor names when errors occur."""
        # Provide enough return values for both prefetch and frozen env symlinks
        self._run(prefetch_venvs_func, None, side_effect=self._failing_then_ok())

        output = capsys.readouterr().out
        assert "Venv prefetching complete! Summary:" in output
        assert "Prefetched: 2" in output
        assert "Failed: 1" in output
Loading