Skip to content

Commit 58d61d6

Browse files
committed
[Local Tensor] Replace dry_run.py with local tensor mode implementation
Replaces `dry_run.py` implementation with local tensor mode for DRY_RUN configuration validation. Local tensor mode provides deeper validation coverage, including `ParallelDims` creation, which the previous implementation could not verify. **Note:** Currently returns early before `init_weights()` due to a known limitation in local tensor mode. This still validates more of the pipeline than the previous approach. ghstack-source-id: 3e00d18 Pull-Request: #2057
1 parent ad9f188 commit 58d61d6

File tree

5 files changed

+74
-169
lines changed

5 files changed

+74
-169
lines changed

run_train.sh

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@ set -ex
1010
# use envs as local overwrites for convenience
1111
# e.g.
1212
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
13-
# DRY_RUN=1 ./run_train.sh # for config validation without GPU
13+
# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU
14+
# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode
1415
NGPU=${NGPU:-"8"}
1516
export LOG_RANK=${LOG_RANK:-0}
1617
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
1718
TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
18-
DRY_RUN=${DRY_RUN:-0}
19+
# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training
20+
COMM_MODE=${COMM_MODE:-""}
1921

2022
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
2123

22-
if [ "$DRY_RUN" = "1" ]; then
23-
# Dry run mode: validate configuration without GPU/distributed setup
24-
echo "Running in DRY RUN mode - configuration validation only"
25-
python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
24+
if [ -n "$COMM_MODE" ]; then
25+
# Communication mode specified: validate configuration or run in debug mode
26+
echo "Running with comm_mode=${COMM_MODE}"
27+
NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.mode=${COMM_MODE} --training.steps=1
2628
else
2729
# Normal training with torchrun
2830
PYTORCH_ALLOC_CONF="expandable_segments:True" \

scripts/dry_run.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

torchtitan/config/job_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,22 @@ class Comm:
791791
save_traces_file_prefix: str = "rank_"
792792
"""Flight recorder trace files prefix"""
793793

794+
mode: Literal["default", "fake_backend", "local_tensor"] = "default"
795+
"""
796+
Communication mode for distributed training.
797+
798+
Options:
799+
- "default": Normal distributed training with real communication
800+
- "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU)
801+
- "local_tensor": Local tensor mode for debugging purposes. There will be only one process
802+
regardless of the number of GPUs. LocalTensor will simulate the computation by running one
803+
rank after another. While the performance will be slow, the numerics should be the same.
804+
This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D
805+
parallelisms within a single node to reduce the combinations we need to use in integration tests.
806+
807+
NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally.
808+
"""
809+
794810

795811
@dataclass
796812
class MemoryEstimation:

torchtitan/distributed/utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,12 +258,51 @@ def maybe_enable_amp(
258258
)
259259

260260

261+
def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"):
262+
"""Initialize fake backend
263+
264+
Args:
265+
world_size: The number of GPUs to simulate
266+
comm_mode: Communication mode ("fake_backend" or "local_tensor")
267+
268+
Returns:
269+
None. The fake process group is initialized as a side effect.
270+
"""
271+
torch.distributed.init_process_group(
272+
"fake",
273+
rank=0,
274+
world_size=world_size,
275+
)
276+
277+
# If local_tensor mode is enabled, initialize LocalTensorMode context
278+
if comm_mode == "local_tensor":
279+
from torch.distributed import _local_tensor
280+
281+
lm = _local_tensor.LocalTensorMode(world_size)
282+
lm.__enter__()
283+
284+
261285
def init_distributed(
262286
comm_config: CommConfig,
263287
enable_cpu_backend: bool = False,
264288
base_folder: str = "",
265289
ranks: list[int] | None = None,
266-
):
290+
) -> int:
291+
if comm_config.mode in ("fake_backend", "local_tensor"):
292+
ngpu_str = os.environ.get("NGPU")
293+
if ngpu_str is None:
294+
raise ValueError(
295+
f"NGPU environment variable must be set when using comm_mode={comm_config.mode}"
296+
)
297+
try:
298+
world_size = int(ngpu_str)
299+
except ValueError as e:
300+
raise ValueError(
301+
f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
302+
) from e
303+
init_fake_mode(world_size, comm_config.mode)
304+
return world_size
305+
267306
def _warn_overwrite_env(env, val):
268307
if env in os.environ:
269308
logger.warning(
@@ -309,6 +348,8 @@ def _get_distributed_backend(enable_cpu_backend):
309348
_ranks=ranks if ranks is not None else [],
310349
)
311350

351+
return torch.distributed.get_world_size()
352+
312353

313354
def set_pg_timeouts(timeout, world_mesh):
314355
"""

torchtitan/train.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,13 @@ def __init__(self, job_config: JobConfig):
360360

361361
def init_distributed(self) -> ParallelDims:
362362
job_config = self.job_config
363-
dist_utils.init_distributed(
363+
world_size = dist_utils.init_distributed(
364364
job_config.comm,
365365
enable_cpu_backend=job_config.training.enable_cpu_offload,
366366
base_folder=job_config.job.dump_folder,
367367
)
368368

369-
world_size = int(os.environ["WORLD_SIZE"])
370369
parallelism_config = job_config.parallelism
371-
372370
return ParallelDims(
373371
dp_shard=parallelism_config.data_parallel_shard_degree,
374372
dp_replicate=parallelism_config.data_parallel_replicate_degree,
@@ -725,6 +723,13 @@ def main(trainer_class: type[Trainer]) -> None:
725723
try:
726724
trainer = trainer_class(config)
727725

726+
# TODO(local_tensor): Remove this special case once LocalTensor supports
727+
# init_weights() and foreach_allgather. In local tensor mode, skip
728+
# training/checkpointing as the model is not fully initialized
729+
if config.comm.mode == "local_tensor":
730+
logger.info("Local tensor mode enabled - skipping training execution")
731+
return
732+
728733
if config.checkpoint.create_seed_checkpoint:
729734
assert (
730735
int(os.environ["WORLD_SIZE"]) == 1

0 commit comments

Comments
 (0)