Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docs/source/guides/execution.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,23 @@ def your_skypilot_executor(nodes: int, devices: int, container_image: str):
return SkypilotExecutor(
gpus="RTX5880-ADA-GENERATION",
gpus_per_node=devices,
nodes = nodes
env_vars=common_envs()
num_nodes = nodes,
env_vars=common_envs(),
container_image=container_image,
cloud="kubernetes",
infra="k8s/mycontext",
# Optional to reuse Skypilot cluster
cluster_name="tester",
volumes={
"nemo-workspace": "nemo-workspace"
},
volume_mounts=[
{
"path": "/data",
"volume_name": "nemo-workspace",
"size": "50Gi",
"type": "k8s-pvc"
}
],
setup="""
conda deactivate
nvidia-smi
Expand Down
101 changes: 100 additions & 1 deletion nemo_run/core/execution/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

from invoke.context import Context

Expand All @@ -27,6 +27,7 @@
Executor,
ExecutorMacros,
)
from nemo_run.core.execution.launcher import FaultTolerance, Torchrun
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

Expand All @@ -36,6 +37,8 @@
import sky.task as skyt
from sky import backends
from sky.utils import status_lib
from sky.volumes import volume as volume_lib
from sky import models

_SKYPILOT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -94,6 +97,8 @@ class SkypilotExecutor(Executor):
memory: Optional[Union[int | float, list[int | float]]] = None
instance_type: Optional[Union[str, list[str]]] = None
num_nodes: int = 1
volumes: Optional[Dict[str, str]] = None
volume_mounts: Optional[List[Any]] = None
use_spot: Optional[Union[bool, list[bool]]] = None
disk_size: Optional[Union[int, list[int]]] = None
disk_tier: Optional[Union[str, list[str]]] = None
Expand Down Expand Up @@ -341,6 +346,73 @@ def macro_values(self) -> Optional[ExecutorMacros]:
het_group_host_var=self.HET_GROUP_HOST_VAR,
)

def _setup_launcher(self):
# Auto-enable torchrun for distributed training scenarios:
# 1. Multi-node training (num_nodes > 1)
# 2. Single-node multi-GPU training (gpus_per_node > 1)
if self.launcher is None and (
self.num_nodes > 1 or (self.gpus_per_node and self.gpus_per_node > 1)
):
self.launcher = "torchrun"

super()._setup_launcher()
launcher = self.launcher
# Dynamic rendezvous has an error in Skypilot Kubernetes currently
if (
launcher
and isinstance(launcher, (Torchrun, FaultTolerance))
and self.cloud == "kubernetes"
):
launcher.rdzv_backend = "static"
launcher.rdzv_port = 49500
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need this part anymore, its fixed in the latest version of Skypilot.


def supports_launcher_transform(self) -> bool:
return True

def _parse_infra_for_volume_config(self) -> dict:
"""Parse infra string and return volume config parameters."""
config = {}

if self.infra is not None:
# Parse infra string to extract cloud, region, zone components
# Format: cloud, cloud/region, cloud/region/zone, k8s/context
infra_parts = self.infra.split("/")
cloud = infra_parts[0] if infra_parts else None

if cloud:
# Special handling for Kubernetes
if cloud == "k8s":
# VolumeConfig region and zone required even though they are marked as optional
# validation fails otherwise
config["cloud"] = "kubernetes"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this value needs to be kubernetes - based on the providers supported list in skypilot

config["region"] = "kubernetes"
config["zone"] = "kubernetes"
else:
# Handle regular cloud providers
config["cloud"] = cloud

# Handle region for non-k8s clouds
if len(infra_parts) >= 2:
region = infra_parts[1]
if region and region != "*": # Skip wildcards
config["region"] = region

# Handle zone for non-k8s clouds
if len(infra_parts) >= 3:
zone = infra_parts[2]
if zone and zone != "*": # Skip wildcards
config["zone"] = zone
else:
# Fall back to individual cloud, region, zone parameters
if self.cloud:
config["cloud"] = self.cloud
if self.region:
config["region"] = self.region
if self.zone:
config["zone"] = self.zone

return config

def to_task(
self,
name: str,
Expand All @@ -364,16 +436,43 @@ def to_task(

{" ".join(cmd)}
"""

task = Task(
name=name,
setup=self.setup if self.setup else "",
run=run_cmd,
envs=self.env_vars,
num_nodes=self.num_nodes,
volumes=self.volumes,
)

file_mounts = self.file_mounts or {}
file_mounts["/nemo_run"] = self.job_dir
task.set_file_mounts(file_mounts)
task.set_volumes(self.volumes)

volume_mounts = []
if self.volume_mounts:
for volume_mount in self.volume_mounts:
# Configure volume based on infra if specified, otherwise use cloud/region/zone
volume_config_kwargs = {
"name": volume_mount["volume_name"],
"type": volume_mount["type"],
"name_on_cloud": volume_mount["volume_name"],
"size": volume_mount["size"],
}

# Add parsed infra configuration
volume_config_kwargs.update(self._parse_infra_for_volume_config())

volume_mounts.append(
volume_lib.VolumeMount(
path=volume_mount["path"],
volume_name=volume_mount["volume_name"],
volume_config=models.VolumeConfig(**volume_config_kwargs),
)
)
task.volume_mounts = volume_mounts
task.set_resources(self.to_resources())

if env_vars:
Expand Down
26 changes: 26 additions & 0 deletions test/core/execution/test_skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,29 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):

# Verify the returned task is our mock
assert result == mock_task_instance

def test_parse_infra_for_volume_config(self, mock_skypilot_imports):
"""Test the _parse_infra_for_volume_config helper method."""

# Test k8s infra
executor_k8s = SkypilotExecutor(infra="k8s/my-context")
config = executor_k8s._parse_infra_for_volume_config()
assert config["cloud"] == "kubernetes"
assert config["region"] == "kubernetes"
assert config["zone"] == "kubernetes"

# Test AWS infra with region and zone
executor_aws = SkypilotExecutor(infra="aws/us-east-1/us-east-1a")
config = executor_aws._parse_infra_for_volume_config()
assert config["cloud"] == "aws"
assert config["region"] == "us-east-1"
assert config["zone"] == "us-east-1a"

# Test fallback to individual parameters
executor_fallback = SkypilotExecutor(
cloud="gcp", region="us-central1", zone="us-central1-a"
)
config = executor_fallback._parse_infra_for_volume_config()
assert config["cloud"] == "gcp"
assert config["region"] == "us-central1"
assert config["zone"] == "us-central1-a"
Loading