165 changes: 165 additions & 0 deletions python/monarch/_src/actor/torchrun.py
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Monarch-based torchrun replacement for distributed PyTorch training.

This module provides a torchrun-compatible launcher that uses Monarch's actor
system instead of torch.distributed.run for process management and coordination.
"""

import argparse
import asyncio
import os
import runpy
import sys

from monarch._src.actor.actor_mesh import Actor, current_rank
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.host_mesh import this_host


class TorchRunActor(Actor):
"""
Actor that sets up PyTorch distributed training environment variables and
executes training scripts.

This actor replicates torchrun's behavior by configuring RANK, WORLD_SIZE,
LOCAL_RANK, MASTER_ADDR, and MASTER_PORT environment variables before
launching the training script or module.
"""

def __init__(
self,
node_idx: int,
nproc_per_node: int,
world_size: int,
master_addr: str,
master_port: str,
) -> None:
"""
Initialize the TorchRunActor.

Args:
node_idx: Index of this node in the multi-node setup (0-indexed).
nproc_per_node: Number of processes per node (typically matches GPU count).
world_size: Total number of processes across all nodes.
master_addr: Address of the master node for rendezvous.
master_port: Port on the master node for rendezvous.
"""
super().__init__()
self.node_idx = node_idx
self.nproc_per_node = nproc_per_node
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port

@endpoint
def main(self, script_args: list[str]) -> bool:
"""
Set up distributed training environment and execute the training script.

Args:
script_args: Arguments for the training script. First element is either
"-m" (for module execution) or the script path, followed by script arguments.

Returns:
True on successful execution.

Raises:
ValueError: If no script or module is specified.
"""
local_rank = current_rank().rank
rank = local_rank + self.nproc_per_node * self.node_idx

os.environ.update(
{
"RANK": str(rank),
"WORLD_SIZE": str(self.world_size),
"LOCAL_RANK": str(local_rank),
"MASTER_ADDR": self.master_addr,
"MASTER_PORT": self.master_port,
}
)

        if script_args and script_args[0] == "-m":
            if len(script_args) < 2:
                raise ValueError("No module specified after '-m'")
            module_name = script_args[1]
            sys.argv = [module_name] + list(script_args[2:])
            runpy.run_module(module_name, run_name="__main__", alter_sys=True)
elif script_args:
script_path = script_args[0]
sys.argv = list(script_args)
runpy.run_path(script_path, run_name="__main__")
else:
raise ValueError("No script or module specified")

return True


async def _run(args: argparse.Namespace) -> None:
"""
Spawn TorchRunActors on a process mesh and execute the training script.

Args:
args: Parsed command-line arguments containing node configuration and script args.
"""
nproc_per_node = int(args.nproc_per_node)
node_idx = int(args.node_idx)
world_size = args.nnodes * nproc_per_node
mesh = this_host().spawn_procs({"gpus": nproc_per_node})
trainers = mesh.spawn(
"spmd_actor",
TorchRunActor,
node_idx=node_idx,
nproc_per_node=nproc_per_node,
world_size=world_size,
master_addr=args.master_addr,
master_port=args.master_port,
)
script_args = args.script_args or []
await trainers.main.call(script_args)


def main() -> None:
"""
Entry point for monarch torchrun launcher.

Parses command-line arguments and launches distributed training using Monarch actors.
Compatible with torchrun's command-line interface for seamless migration.
"""
parser = argparse.ArgumentParser(
description="Monarch-based torchrun replacement for distributed PyTorch training"
)
parser.add_argument(
"--nproc_per_node",
type=int,
default=2,
help="Number of processes per node (typically matches GPU count)",
)
parser.add_argument("--nnodes", type=int, default=1, help="Total number of nodes")
parser.add_argument(
"--node_idx", type=int, default=0, help="Index of this node (0-indexed)"
)
parser.add_argument(
"--master_addr", default="localhost", help="Address of the master node"
)
parser.add_argument(
"--master_port", default="29500", help="Port on the master node"
)
parser.add_argument(
"script_args",
nargs=argparse.REMAINDER,
help="Script/module and arguments to execute",
)
args = parser.parse_args()

asyncio.run(_run(args))


if __name__ == "__main__":
main()
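
As a quick sanity check of the rank arithmetic in TorchRunActor.main (an illustrative snippet, not part of this diff, using made-up node and process counts): each process's global rank is its local rank offset by nproc_per_node * node_idx, so the ranks across all nodes cover 0..world_size-1 exactly once.

# Illustrative only: verify the rank layout used by TorchRunActor.main
# with hypothetical values (2 nodes, 4 processes per node).
nnodes, nproc_per_node = 2, 4
world_size = nnodes * nproc_per_node

ranks = [
    local_rank + nproc_per_node * node_idx
    for node_idx in range(nnodes)
    for local_rank in range(nproc_per_node)
]
assert sorted(ranks) == list(range(world_size))  # global ranks 0..7, no gaps or duplicates
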
19 changes: 19 additions & 0 deletions python/monarch/actor/torchrun.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Public interface for monarch torchrun launcher.

This module provides a command-line entry point compatible with torchrun.
Use `python -m monarch.actor.torchrun` to launch distributed training.
"""

from monarch._src.actor.torchrun import main

if __name__ == "__main__":
main()
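
For reference, a minimal sketch of driving the launcher programmatically rather than via the command line; it assumes the monarch package is installed and uses a hypothetical train.py script and --epochs flag purely for illustration. The equivalent shell invocation would be python -m monarch.actor.torchrun --nproc_per_node 4 train.py --epochs 1.

# Illustrative only: programmatic equivalent of
#   python -m monarch.actor.torchrun --nproc_per_node 4 train.py --epochs 1
# "train.py" and "--epochs" are hypothetical placeholders.
import sys

from monarch.actor.torchrun import main

sys.argv = [
    "torchrun",
    "--nproc_per_node", "4",
    "--nnodes", "1",
    "--node_idx", "0",
    "--master_addr", "localhost",
    "--master_port", "29500",
    "train.py", "--epochs", "1",
]
main()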