diff --git a/run_ft.sh b/run_ft.sh
new file mode 100755
index 0000000000..d5aba4c3ec
--- /dev/null
+++ b/run_ft.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC=0 TORCH_NCCL_DUMP_ON_TIMEOUT=0 TORCH_NCCL_TRACE_BUFFER_SIZE=0 TORCH_SHARE_RDZV_TCP_STORE=1 LOGLEVEL=INFO NCCL_DEBUG_SUBSYS=ALL NCCL_DEBUG=INFO TORCH_CPP_LOG_LEVEL=INFO CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_train.sh --training.local_batch_size=2 --parallelism.data_parallel_shard_degree=2 --profiling.enable_profiling --profiling.profile_freq=1 --profiling.profiler_active=1 --profiling.profiler_warmup=0 --training.steps=1000 --comm.train_timeout_seconds=1 --comm.trace_buf_size=0
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index d4c5416aa2..79918d0046
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -189,15 +189,25 @@ def __init__(
         self.enable = checkpoint_config.enable
         self.load_only = checkpoint_config.load_only
 
+        self.states = states
+        self.states.update(
+            {
+                MODEL: ModelWrapper(model_parts),
+                OPTIMIZER: optimizers,
+                DATALOADER: dataloader,
+                LR_SCHEDULER: lr_schedulers,
+            }
+        )
+
         self.ft_manager = (
-            ft_manager.manager
-            if ft_manager
-            and ft_manager.enabled
-            and checkpoint_config.enable_ft_dataloader_checkpoints
-            else None
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
         )
 
-        if ft_manager and ft_manager.enabled and not self.ft_manager:
+        self.enable_ft_dataloader_checkpoints = (
+            self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints
+        )
+
+        if self.ft_manager and not self.enable_ft_dataloader_checkpoints:
             logger.warn(
                 "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
                 "This means replicas can retrain over the same data multiple times, which can result in overfitting."
@@ -229,20 +239,11 @@ def load_state_dict(state_dict):
         async_mode = checkpoint_config.async_mode.lower()
         self.enable_staging = (
             self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        ) or self.ft_manager
+        ) or self.enable_ft_dataloader_checkpoints
 
-        if not self.enable and self.ft_manager is None:
+        if not self.enable and not self.enable_ft_dataloader_checkpoints:
             return
 
-        self.states = states
-        self.states.update(
-            {
-                MODEL: ModelWrapper(model_parts),
-                OPTIMIZER: optimizers,
-                DATALOADER: dataloader,
-                LR_SCHEDULER: lr_schedulers,
-            }
-        )
         self.ft_states = {DATALOADER: dataloader}
 
         self.staging = False
@@ -279,7 +280,7 @@ def load_state_dict(state_dict):
         if (
             async_mode == AsyncMode.ASYNC
             or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            or self.ft_manager
+            or self.enable_ft_dataloader_checkpoints
         ):
             self.pg = dist.new_group(backend="gloo")
 
@@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             None
         """
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
            self._ft_save(curr_step)
 
         if not self._should_save(curr_step, last_step):
             return
 
         begin = time.monotonic()
-        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+        if not self.enable_ft_dataloader_checkpoints or (
+            self.ft_manager and self.ft_manager.participating_rank() == 0
+        ):
             logger.info("Saving the checkpoint (or staging if async is enabled).")
             checkpoint_id = self._create_checkpoint_id(curr_step)
             self._async_wait()
@@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
                 "Finished saving the checkpoint (or staging if async is enabled)"
                 f"in {time.monotonic() - begin:.2f} seconds."
             )
-        elif self.ft_manager:
+        elif self.enable_ft_dataloader_checkpoints:
+            assert self.ft_manager is not None
             logger.info(
                 "Replica %d doesn't save checkpoint.",
                 self.ft_manager.participating_rank(),
             )
@@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             self._ft_load()
 
         if not self.enable:
@@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
 
         states_to_load = self._flattened_model_states_sd(states_to_load)
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             states_to_load.pop(DATALOADER)
 
         return states_to_load
@@ -805,7 +809,9 @@ def _async_wait(self) -> None:
         if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             if self.save_future is not None:
                 self.save_future.result()
-        elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None:
+        elif (
+            self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints
+        ):
             if self.save_future is not None:
                 self.save_future.result()
                 self.save_future = None
@@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
-            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
+            and (
+                not self.enable_ft_dataloader_checkpoints
+                or (self.ft_manager and self.ft_manager.participating_rank() == 0)
+            )
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 93a96a4439..b80beae8cc
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -232,7 +232,7 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] = []
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -276,6 +276,7 @@ def _get_distributed_backend(enable_cpu_backend):
 
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        _ranks=ranks,
     )
diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index d638a6bd26..f3ce1e88d9
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -281,12 +281,13 @@ def train(self):
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}.")
 
+        torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index f398dba9b5..049c780a73
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -18,7 +18,6 @@
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
 
-@contextlib.contextmanager
 def maybe_enable_profiling(
     profiling_config: ProfilingConfig,
     *,
@@ -68,7 +67,7 @@ def trace_handler(prof):
             gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
         elif torch.xpu.is_available():
             gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
-        with torch.profiler.profile(
+        torch_profiler = torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 gpu_device_profiled,
@@ -76,12 +75,12 @@ def trace_handler(prof):
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             record_shapes=True,
-        ) as torch_profiler:
-            torch_profiler.step_num = global_step
-            yield torch_profiler
+        )
+        torch_profiler.step_num = global_step
+        torch_profiler.start()
+        return torch_profiler
     else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
+        return None
 
 
 @contextlib.contextmanager
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 2efd7931ed..37118c5acb
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
     maybe_enable_profiling,
 )
 
+c_globals = ctypes.CDLL(None)  # POSIX
+
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
@@ -83,11 +89,21 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
 
+        ranks = []
+        ft_config = job_config.fault_tolerance
+        if ft_config.enable:
+            group_size = ft_config.group_size
+            replica_id = ft_config.replica_id
+            first_rank = replica_id * group_size
+            last_rank = first_rank + group_size - 1
+            ranks = list(range(first_rank, last_rank + 1))
+
         # init distributed and build meshes
         dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
+            ranks=ranks,
         )
 
         job_config.maybe_log()
@@ -570,13 +586,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
@@ -600,6 +617,15 @@ def train(self):
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signal):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
@@ -623,8 +649,8 @@ def train(self):
                     self.validator.validate(self.model_parts, self.step)
 
                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()
 
@@ -633,7 +659,7 @@ def train(self):
                 if self.step == 1:
                     dist_utils.set_pg_timeouts(
                         timeout=timedelta(
-                            seconds=job_config.comm.train_timeout_seconds
+                            milliseconds=job_config.comm.train_timeout_seconds
                         ),
                         world_mesh=self.parallel_dims.world_mesh,
                     )
@@ -682,10 +708,12 @@ def close(self) -> None:
         else:
             trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")