Commit c75f8ca

colin2328 authored and facebook-github-bot committed
add monarch.torchrun module; replicate torchrun using monarch as a proc manager
Differential Revision: D86152013
1 parent 22ee7b4 commit c75f8ca
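
For orientation, a minimal usage sketch (not part of the commit): after this change the module is re-exported from monarch.actor, so a torchrun-style invocation can be driven programmatically. The training script name and its flags below are hypothetical.

import sys

from monarch.actor import torchrun  # re-exported by python/monarch/actor/__init__.py in this commit

# Roughly equivalent to: torchrun --nnodes 1 --nproc_per_node 2 train.py --epochs 1
sys.argv = [
    "torchrun",               # program name; argparse ignores argv[0]
    "--nnodes", "1",
    "--nproc_per_node", "2",
    "train.py",               # hypothetical training script
    "--epochs", "1",          # hypothetical script argument
]
torchrun.main()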

File tree

2 files changed: +169 -0 lines changed

2 files changed

+169
-0
lines changed
python/monarch/_src/actor/torchrun.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Monarch-based torchrun replacement for distributed PyTorch training.

This module provides a torchrun-compatible launcher that uses Monarch's actor
system instead of torch.distributed.run for process management and coordination.
"""

import argparse
import asyncio
import os
import runpy
import sys

from monarch._src.actor.actor_mesh import Actor, current_rank
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.host_mesh import this_host


class TorchRunActor(Actor):
    """
    Actor that sets up PyTorch distributed training environment variables and
    executes training scripts.

    This actor replicates torchrun's behavior by configuring the RANK, WORLD_SIZE,
    LOCAL_RANK, MASTER_ADDR, and MASTER_PORT environment variables before
    launching the training script or module.
    """

    def __init__(
        self,
        node_idx: int,
        nproc_per_node: int,
        world_size: int,
        master_addr: str,
        master_port: str,
    ) -> None:
        """
        Initialize the TorchRunActor.

        Args:
            node_idx: Index of this node in the multi-node setup (0-indexed).
            nproc_per_node: Number of processes per node (typically matches GPU count).
            world_size: Total number of processes across all nodes.
            master_addr: Address of the master node for rendezvous.
            master_port: Port on the master node for rendezvous.
        """
        super().__init__()
        self.node_idx = node_idx
        self.nproc_per_node = nproc_per_node
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port

    @endpoint
    def main(self, script_args: list[str]) -> bool:
        """
        Set up the distributed training environment and execute the training script.

        Args:
            script_args: Arguments for the training script. The first element is either
                "-m" (for module execution) or the script path, followed by the script's
                own arguments.

        Returns:
            True on successful execution.

        Raises:
            ValueError: If no script or module is specified.
        """
        local_rank = current_rank().rank
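        # Global rank = this process's local rank plus the ranks assigned to earlier nodes.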
rank = local_rank + self.nproc_per_node * self.node_idx
79+
80+
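        # Same environment contract torchrun provides: RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
        # drive torch.distributed's env:// rendezvous; LOCAL_RANK is used for device selection.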
        os.environ.update(
            {
                "RANK": str(rank),
                "WORLD_SIZE": str(self.world_size),
                "LOCAL_RANK": str(local_rank),
                "MASTER_ADDR": self.master_addr,
                "MASTER_PORT": self.master_port,
            }
        )

        if script_args and script_args[0] == "-m":
            module_name = script_args[1]
            sys.argv = [module_name] + list(script_args[2:])
            runpy.run_module(module_name, run_name="__main__", alter_sys=True)
        elif script_args:
            script_path = script_args[0]
            sys.argv = list(script_args)
            runpy.run_path(script_path, run_name="__main__")
        else:
            raise ValueError("No script or module specified")

        return True


async def _run(args: argparse.Namespace) -> None:
    """
    Spawn TorchRunActors on a process mesh and execute the training script.

    Args:
        args: Parsed command-line arguments containing node configuration and script args.
    """
    nproc_per_node = int(args.nproc_per_node)
    node_idx = int(args.node_idx)
    world_size = args.nnodes * nproc_per_node
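    # One process per local GPU on this host; a TorchRunActor is spawned in each process below.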
    mesh = this_host().spawn_procs({"gpus": nproc_per_node})
    trainers = mesh.spawn(
        "spmd_actor",
        TorchRunActor,
        node_idx=node_idx,
        nproc_per_node=nproc_per_node,
        world_size=world_size,
        master_addr=args.master_addr,
        master_port=args.master_port,
    )
    script_args = args.script_args or []
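    # call() broadcasts main() to every actor in the mesh and waits for all of them to complete.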
    await trainers.main.call(script_args)


def main() -> None:
    """
    Entry point for the Monarch torchrun launcher.

    Parses command-line arguments and launches distributed training using Monarch actors.
    Compatible with torchrun's command-line interface for seamless migration.
    """
    parser = argparse.ArgumentParser(
        description="Monarch-based torchrun replacement for distributed PyTorch training"
    )
    parser.add_argument(
        "--nproc_per_node",
        type=int,
        default=2,
        help="Number of processes per node (typically matches GPU count)",
    )
    parser.add_argument(
        "--nnodes", type=int, default=1, help="Total number of nodes"
    )
    parser.add_argument(
        "--node_idx", type=int, default=0, help="Index of this node (0-indexed)"
    )
    parser.add_argument(
        "--master_addr", default="localhost", help="Address of the master node"
    )
    parser.add_argument(
        "--master_port", default="29500", help="Port on the master node"
    )
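    # argparse.REMAINDER passes everything after the recognized flags through verbatim
    # (e.g. "-m some.module ..." or "train.py ..."), to be interpreted by TorchRunActor.main().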
    parser.add_argument(
        "script_args",
        nargs=argparse.REMAINDER,
        help="Script/module and arguments to execute",
    )
    args = parser.parse_args()

    asyncio.run(_run(args))


if __name__ == "__main__":
    main()
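
The environment variables exported by TorchRunActor are the ones an ordinary PyTorch training script already expects from torchrun. A minimal sketch of such a script (illustrative only, not part of this commit; the backend choice and file name are assumptions):

# train.py (hypothetical) - consumes the variables exported by TorchRunActor
import os

import torch
import torch.distributed as dist


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    # The default env:// rendezvous reads RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    print(f"rank {dist.get_rank()}/{dist.get_world_size()} ready (local rank {local_rank})")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()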

python/monarch/actor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import Future
 from monarch._src.actor.supervision import unhandled_fault_hook
+from monarch._src.actor import torchrun

 from monarch._src.actor.v1 import enabled as v1_enabled

@@ -108,4 +109,5 @@
     "ChannelTransport",
     "unhandled_fault_hook",
     "MeshFailure",
+    "torchrun",
 ]
