16 changes: 14 additions & 2 deletions torchtitan/experiments/torchcomms/train.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.distributed import ParallelDims
+import os
+
+from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.train import main, Trainer
 
 from .parallel_dims import TorchCommsParallelDims
@@ -13,7 +15,17 @@
 class TorchCommsTrainer(Trainer):
     parallel_dims: TorchCommsParallelDims
 
-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return TorchCommsParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
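For orientation, a minimal sketch (not the actual torchtitan source; config plumbing and the real ParallelDims fields are omitted) of the hook pattern this change settles on: the base Trainer.__init__ calls a single self.init_distributed() method, so a subclass such as TorchCommsTrainer only overrides that one hook to return its own ParallelDims flavor.

class ParallelDims: ...
class TorchCommsParallelDims(ParallelDims): ...


class Trainer:
    def __init__(self) -> None:
        # One hook now owns process-group init and ParallelDims construction.
        self.parallel_dims = self.init_distributed()

    def init_distributed(self) -> ParallelDims:
        # Initialize process groups, read WORLD_SIZE, build the default dims.
        return ParallelDims()


class TorchCommsTrainer(Trainer):
    def init_distributed(self) -> ParallelDims:
        # Same setup, but returns the torchcomms-specific dims.
        return TorchCommsParallelDims()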
29 changes: 15 additions & 14 deletions torchtitan/train.py
@@ -84,20 +84,10 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
 
-        # init distributed and build meshes
-        dist_utils.init_distributed(
-            job_config.comm,
-            enable_cpu_backend=job_config.training.enable_cpu_offload,
-            base_folder=job_config.job.dump_folder,
-        )
-
         job_config.maybe_log()
 
-        world_size = int(os.environ["WORLD_SIZE"])
-        parallelism_config = job_config.parallelism
-        self.parallel_dims = parallel_dims = self._create_parallel_dims(
-            parallelism_config, world_size
-        )
+        # init distributed and build meshes
+        self.parallel_dims = parallel_dims = self.init_distributed()
 
         world_mesh = parallel_dims.world_mesh
         if parallel_dims.dp_enabled:
@@ -319,7 +309,8 @@ def __init__(self, job_config: JobConfig):
         )
 
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
@@ -367,7 +358,17 @@ def __init__(self, job_config: JobConfig):
             f"(warmup {job_config.lr_scheduler.warmup_steps})"
         )
 
-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
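One practical note on the refactor: init_distributed() reads WORLD_SIZE from the environment, so that variable (normally set by torchrun, along with the usual rendezvous settings) must exist before the Trainer is constructed. A hypothetical single-process setup for a local smoke test might look like:

import os

# Illustrative values only; torchrun normally exports these before main() runs.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# trainer = Trainer(job_config)  # __init__ now calls self.init_distributed()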