Merged (changes from 7 commits)

torchtitan/experiments/torchcomms/train.py (14 additions, 2 deletions)

@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.distributed import ParallelDims
+import os
+
+from torchtitan.config import JobConfig
+from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.train import main, Trainer

 from .parallel_dims import TorchCommsParallelDims

@@ -13,7 +16,16 @@
 class TorchCommsTrainer(Trainer):
     parallel_dims: TorchCommsParallelDims

-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed_env(self, job_config: JobConfig) -> ParallelDims:

[Contributor comment on the init_distributed_env line]

This PR doesn't do anything new with the distributed backend. Are you saying we will modify this code for torchcomms?

Asking because I'm a bit concerned about sending the whole job_config to this function, which seems a regression in terms of config usage hygiene.


+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return TorchCommsParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
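
On the config-hygiene point raised in the comment above, here is a minimal sketch of the narrower interface it hints at: a hook that receives only the fields init_distributed actually reads, leaving the full JobConfig inside Trainer.__init__. This is hypothetical and not part of the PR; the function name and signature are invented for illustration, while the dist_utils.init_distributed call and the config fields mirror the ones in the diff.

import os

from torchtitan.distributed import utils as dist_utils


def init_distributed_backend(comm_config, enable_cpu_offload: bool, dump_folder: str) -> int:
    """Hypothetical narrower hook: initialize the process group, return the world size.

    The arguments mirror job_config.comm, job_config.training.enable_cpu_offload,
    and job_config.job.dump_folder from the diff above.
    """
    dist_utils.init_distributed(
        comm_config,
        enable_cpu_backend=enable_cpu_offload,
        base_folder=dump_folder,
    )
    return int(os.environ["WORLD_SIZE"])

Under that split, __init__ would keep calling the existing _create_parallel_dims(parallelism_config, world_size) instead of folding everything into one job_config-taking method; whether the extra plumbing is worth it is exactly the trade-off the comment raises.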

torchtitan/train.py (14 additions, 14 deletions)

@@ -84,20 +84,10 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)

-        # init distributed and build meshes
-        dist_utils.init_distributed(
-            job_config.comm,
-            enable_cpu_backend=job_config.training.enable_cpu_offload,
-            base_folder=job_config.job.dump_folder,
-        )
-
         job_config.maybe_log()

-        world_size = int(os.environ["WORLD_SIZE"])
-        parallelism_config = job_config.parallelism
-        self.parallel_dims = parallel_dims = self._create_parallel_dims(
-            parallelism_config, world_size
-        )
+        # init distributed and build meshes
+        self.parallel_dims = parallel_dims = self.init_distributed_env(job_config)

         world_mesh = parallel_dims.world_mesh
         if parallel_dims.dp_enabled:

@@ -319,7 +309,8 @@ def __init__(self, job_config: JobConfig):
         )

         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(

@@ -367,7 +358,16 @@ def __init__(self, job_config: JobConfig):
             f"(warmup {job_config.lr_scheduler.warmup_steps})"
         )

-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed_env(self, job_config: JobConfig) -> ParallelDims:
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
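
Taken together, the two files implement a small template-method refactor: Trainer.__init__ now delegates process-group init and ParallelDims construction to one overridable hook, and TorchCommsTrainer swaps in TorchCommsParallelDims by overriding it. Below is a toy sketch of that pattern with stand-in classes, not torchtitan code, just to show why the base __init__ no longer needs to know which ParallelDims flavor it gets.

from dataclasses import dataclass


@dataclass
class Dims:
    dp_shard: int
    dp_replicate: int


@dataclass
class CommsDims(Dims):
    # stand-in for TorchCommsParallelDims
    pass


class BaseTrainer:
    def __init__(self, cfg: dict) -> None:
        # __init__ owns the flow but not the concrete Dims type
        self.parallel_dims = self.init_distributed_env(cfg)

    def init_distributed_env(self, cfg: dict) -> Dims:
        # base behavior; default backend init would live here
        return Dims(cfg["dp_shard"], cfg["dp_replicate"])


class CommsTrainer(BaseTrainer):
    def init_distributed_env(self, cfg: dict) -> Dims:
        # override point; a real subclass would also set up its own backend here
        return CommsDims(cfg["dp_shard"], cfg["dp_replicate"])


if __name__ == "__main__":
    trainer = CommsTrainer({"dp_shard": 2, "dp_replicate": 1})
    print(type(trainer.parallel_dims).__name__)  # CommsDims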