diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py
index 53422b362f6..e4d80517d93 100644
--- a/megatron/core/dist_checkpointing/strategies/base.py
+++ b/megatron/core/dist_checkpointing/strategies/base.py
@@ -213,3 +213,4 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str,
         """Each async strategy can be trivially used as a sync strategy."""
         async_request = self.async_save(sharded_state_dict, checkpoint_dir)
         async_request.execute_sync()
+        del async_request
diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py
index b23c4e9893d..e1129a69593 100644
--- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py
+++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py
@@ -8,6 +8,7 @@
 import os
 import pickle
 import queue
+import threading
 from functools import partial
 from heapq import heappop, heappush
 from itertools import chain
@@ -67,7 +68,7 @@ class FileSystemWriterAsync(FileSystemWriter):
     1. Call `write_data`
     2. Externally start an async process with `get_save_function_and_args` and its arguments.
     3. The async function `writer_proxy_func` calls `write_preloaded_data` across multiple
-    processes.
+    threads (no child processes).
     4. Once saving is finalized on all ranks, call `super().finish` with the results stored
     in `self.writer_result`.
 
@@ -150,7 +151,7 @@ def _clone_if_needed(ten: torch.Tensor):
             is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
             return ten.clone() if is_view else ten
 
-        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
+        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer thread
         self.write_buckets = []
         for group_name, group_buckets in _split_by_separation_hint(
             item_buckets, self.separation_hint
@@ -203,7 +204,7 @@ def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Calla
             return None, None, []
         transform_list = [self.transforms] if hasattr(self, "transforms") else []
         return (
-            partial(self.write_preloaded_data_multiproc, transform_list, self.use_msc),
+            partial(self.write_preloaded_data_multithread, transform_list, self.use_msc),
             partial(self.preload_tensors, self.write_buckets, True),
             [torch.distributed.get_rank(), self.write_buckets, self.results_queue],
         )
@@ -222,17 +223,20 @@ def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List
         for bucket in write_buckets:
             file_name, storage_key, (bytes_data, tensor_data) = bucket
-            tensor_data = [
-                (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
-            ]
-            result.append((file_name, storage_key, (bytes_data, tensor_data)))
+            tensor_list = []
+            for item, tensor in tensor_data:
+                # we believe these tensors are detached from the model trainers
+                tensor_list.append((item, tensor.to("cpu", non_blocking=non_blocking)))
+                # This is required for `PersistentAsyncCaller` to release the reference
+                del tensor
+            result.append((file_name, storage_key, (bytes_data, tensor_list)))
         if non_blocking:
             torch.cuda.synchronize()
         return result
 
     @staticmethod
     @_disable_gc()
-    def write_preloaded_data_multiproc(
+    def write_preloaded_data_multithread(
         transform_list: List[_StorageWriterTransforms],
         use_msc: bool,
         rank: int,
         write_buckets: List[WriteBucket],
@@ -240,37 +244,40 @@ def write_preloaded_data_multiproc(
         global_results_queue: mp.Queue,
     ) -> None:
         """
-        Performs saving data to storage with multiple processes.
+        Performs saving data to storage with multiple threads.
 
-        Starts predefined number of processes and uses 2 queues to make sure the results
-        are complete:
-        - local_results_queue - to send the actual results
-        - count_queue - small queue to mark worker as completed
+        Uses threads (not processes) so that it can run safely inside a daemon process
+        without spawning child processes. Uses two queues:
+        - local_results_queue - to collect write results from worker threads
+        - count_queue - to signal worker completion (task_done/join).
 
-        Using just one queue disallowed proper exception handling.
-
-        This method is meant to be run in a forked subprocess.
-        Triggering GC during execution leads to CUDA errors
-        (cleaning up tensors owned by the parent process).
+        Triggering GC during execution can lead to CUDA errors when tensors are shared.
         To prevent this, we disable the GC explicitly for this function with _disable_gc.
 
         Args:
             write_buckets (List[WriteBucket]): write plan
-            global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
-            (or an Exception) from parallel write processes to the main training process
+            global_results_queue (mp.Queue): queue to send a Dict[int, List[WriteResult]]
+            (or an Exception) back to the main training process
 
         Returns: None
         """
         logger = logging.getLogger(__name__)
         w_start = time()
         write_results_or_exc: Union[dict, Exception] = dict()
-        ctx = mp.get_context("fork")
-        local_results_queue = ctx.Queue()
-        count_queue = ctx.JoinableQueue()
-        p_list = []
+        local_results_queue: queue.Queue = queue.Queue()
+        count_queue: queue.Queue = queue.Queue()
+        thread_list: List[threading.Thread] = []
+
+        def check_local_output(local_results_or_exc, local_worker_idx):
+            if isinstance(local_results_or_exc, Exception):
+                err_msg = (
+                    f"Local worker {local_worker_idx} encountered"
+                    f" an error: {local_results_or_exc}"
+                )
+                logger.error(err_msg)
+            assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
+
         for i, write_bucket in enumerate(write_buckets):
             try:
-                count_queue.put(i)
-
                 kwargs = {
                     "local_proc_idx": i,
                     "write_bucket": write_bucket,
@@ -280,62 +287,63 @@ def write_preloaded_data_multiproc(
                 }
                 if use_msc:
-                    import inspect
-
-                    # Remove the inspect after the test_async_save.py is fixed.
                     signature = inspect.signature(FileSystemWriterAsync.write_preloaded_data)
                     if len(signature.parameters) > 6:
-                        kwargs["use_msc"] = use_msc
-
-                p_list.append(
-                    ctx.Process(
+                        kwargs['use_msc'] = use_msc
+                # Parallel writers: spawn threads for all but the last bucket
+                if i < len(write_buckets) - 1:
+                    count_queue.put(i)
+                    t = threading.Thread(
                         target=partial(FileSystemWriterAsync.write_preloaded_data, transform_list),
                         kwargs=kwargs,
                     )
-                )
+                    thread_list.append(t)
+                else:
+                    kwargs['count_queue'] = None
+                    kwargs['results_queue'] = None
+                    logger.debug('FileSystemWriterAsync: main worker started')
+                    local_output = FileSystemWriterAsync.write_preloaded_data(
+                        transform_list, **kwargs
+                    )
+                    if local_output is not None:
+                        logger.debug(
+                            'FileSystemWriterAsync: main worker results successfully collected'
+                        )
+                        check_local_output(local_output[1], local_output[0])
+                        write_results_or_exc[local_output[0]] = local_output[1]
+
             except Exception as e:
-                err_msg = f"An error is caught while a proc {i} is created, error: {e}"
+                err_msg = f"An error occurred while starting worker {i}: {e}"
                 logger.error(err_msg)
                 write_results_or_exc = RuntimeError(err_msg)
 
-        if not isinstance(write_results_or_exc, Exception):
-            for p in p_list:
-                p.start()
+        if not isinstance(write_results_or_exc, Exception) and len(thread_list) > 0:
+            for t in thread_list:
+                t.start()
 
             logger.debug("FileSystemWriterAsync: collecting worker results...")
-            # To make sure all nodes are completed
             count_queue.join()
-            # At this point, all workers completed, so the queue should have exactly
-            # `len(write_buckets)` items
-            for proc_idx in range(len(write_buckets)):
+            for _ in range(len(write_buckets) - 1):
                 try:
                     local_proc_idx, local_results_or_exc = local_results_queue.get()
                 except queue.Empty:
                     write_results_or_exc = RuntimeError(
                         "Unexpected empty `local_results_queue`"
-                        f" (got only {proc_idx}/{len(write_buckets)} items)"
+                        f" (expected {len(write_buckets) - 1} items)"
                     )
                     break
                 else:
-                    if isinstance(local_results_or_exc, Exception):
-                        err_msg = (
-                            f"Local process {local_proc_idx} encountered"
-                            f" an error: {local_results_or_exc}"
-                        )
-                        logger.error(err_msg)
-                        write_results_or_exc = local_results_or_exc
-                        break
-                    assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
+                    check_local_output(local_results_or_exc, local_proc_idx)
                     write_results_or_exc[local_proc_idx] = local_results_or_exc
-                    p_list[local_proc_idx].join()
-
-            logger.debug("FileSystemWriterAsync: collected worker results successfully")
+            for t in thread_list:
+                t.join()
+            logger.debug('FileSystemWriterAsync: collected worker results successfully')
 
         global_results_queue.put(write_results_or_exc)
 
         w_end = time()
-        logger.debug(f"{w_end}, rank: {rank}, write(sync,parallel): {w_end - w_start}")
+        logger.debug(f"{w_end}, rank: {rank}, write(sync,threads): {w_end - w_start}")
 
     @staticmethod
     @_disable_gc()
@@ -343,23 +351,22 @@ def write_preloaded_data(
         transform_list: List[_StorageWriterTransforms],
         local_proc_idx: int,
         write_bucket: WriteBucket,
-        results_queue: mp.SimpleQueue,
-        count_queue: mp.JoinableQueue,
+        results_queue: Optional[queue.Queue],
+        count_queue: Optional[queue.Queue],
         use_fsync: bool,
         **kwargs,
-    ) -> None:
+    ) -> Union[Tuple[int, Exception], None]:
         """
-        Performs actual data saving to storage.
+        Performs actual data saving to storage (run by worker threads or inline by the main worker).
 
         Args:
-            local_proc_idx (int): index of a local process that performs writing
+            local_proc_idx (int): index of the worker that performs writing
             write_bucket (WriteBucket): data to write to storage
-            results_queue (mp.Queue): queue to return the write results
-            to the proxy checkpoint process.
-            count_queue (mp.JoinableQueue): queue to marks worker task as completed
+            results_queue (Optional[queue.Queue]): queue to return the write results (None for the main worker)
+            count_queue (Optional[queue.Queue]): queue to signal worker task completion (get + task_done)
             use_fsync (bool): if True, calls os.fsync at the end of saving
 
-        Returns: None, the write result are put into the `queue`
+        Returns: None in worker threads (results are put into the queue); the result tuple when run inline
         """
         logger = logging.getLogger(__name__)
         logger.debug(f"{local_proc_idx} started")
@@ -405,17 +412,19 @@ def write_preloaded_data(
         except Exception as e:
             logger.debug(f"{local_proc_idx} failed")
             local_output = (local_proc_idx, e)  # type: ignore[assignment]
-
-        results_queue.put(local_output)
-        # Signal this process is done.
-        count_queue.get()
-        count_queue.task_done()
+        if results_queue is not None:
+            results_queue.put(local_output)
+        if count_queue is not None:
+            # Signal this worker is done.
+            count_queue.get()
+            count_queue.task_done()
 
         mem_after = _process_memory()
         logger.debug(
             f"{local_proc_idx} consumed: {mem_after - mem_before},"
             f" before: {mem_before}, after: {mem_after}"
         )
+        return local_output
 
     def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
         """Write all items from ``plan``."""
@@ -427,7 +436,7 @@ def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]:
         into a single results lists. Includes error check.
 
         Returns (Union(List[WriteResult], WRAPPED_EXCEPTION): the list of write results
-        from all local processes performing the save, or a WRAPPED_EXCEPTION if
+        from all local workers (threads) performing the save, or a WRAPPED_EXCEPTION if
         an exception was raised during the writing process.
         """
         assert self.write_buckets is not None
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index 6baebb0db99..645ea4be2c3 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -599,7 +599,7 @@ def __init__(
         backend: str,
         version: int,
         keep_only_main_replica: bool = True,
-        thread_count: int = 2,
+        thread_count: int = 1,
         cached_metadata: bool = False,
         separation_hint: Optional[str] = None,
     ):
@@ -658,6 +658,10 @@ def async_save(
             )
         )
         pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
+
+        if self.separation_hint is not None and self.thread_count <= 1:
+            self.thread_count = 2
+
         # Use PyT saving mechanism
         writer = FileSystemWriterAsync(
             checkpoint_dir,
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 20e749dbc21..b3ffbc59b51 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1391,7 +1391,7 @@ def validate_args(args, defaults={}):
             'Disabling --async-save.'
         )
         args.async_save = False
-
+
     # Inference args
     if args.inference_batch_times_seqlen_threshold > -1:
         assert args.pipeline_model_parallel_size > 1, \
@@ -2448,6 +2448,10 @@ def _add_checkpointing_args(parser):
     group.add_argument('--dist-ckpt-format',
                        dest='dist_ckpt_format_deprecated',
                        help='Deprecated: see --ckpt-format.')
+    group.add_argument('--dist-ckpt-workers', type=int, default=1,
+                       help='Number of worker threads for distributed checkpointing. '
+                       'Only used for async save. '
+                       'If set to 1, writing is performed without extra worker threads.')
     group.add_argument('--ckpt-fully-parallel-save', action='store_true',
                        dest='ckpt_fully_parallel_save_deprecated',
                        help='Deprecated: see --no-ckpt-fully-parallel-save.')
diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index 44a0a70b6d4..f3592069afe 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -616,6 +616,13 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
             save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
             if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
                 save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
+            if args.async_save:
+                save_strategy.thread_count = args.dist_ckpt_workers
+            elif args.dist_ckpt_workers > 1:
+                # We don't allow per-rank parallel save for sync save
+                logger.warning('Per-rank parallel save is not supported for sync save. '
+                               'Falling back to a single checkpoint worker.')
+                save_strategy.thread_count = 1
             if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
                 cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
                 if cached_global_metadata is not None:
diff --git a/tests/unit_tests/dist_checkpointing/test_torch_dist.py b/tests/unit_tests/dist_checkpointing/test_torch_dist.py
index 4f4df058977..64a47b8cbb3 100644
--- a/tests/unit_tests/dist_checkpointing/test_torch_dist.py
+++ b/tests/unit_tests/dist_checkpointing/test_torch_dist.py
@@ -11,7 +11,6 @@
 from megatron.core.dist_checkpointing import ShardedTensor, load, save
 from megatron.core.dist_checkpointing.dict_utils import diff
 from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
-from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
 from tests.unit_tests.dist_checkpointing import TempNamedDir
 from tests.unit_tests.test_utilities import Utils
 
@@ -117,7 +116,7 @@ def test_cpu_tensors_dont_take_too_much_space(self, tmp_path_dist_ckpt):
         ) as ckpt_dir:
             save(sharded_state_dict, ckpt_dir)
 
-            distcp_files = [(ckpt_dir / '__0_0.distcp'), (ckpt_dir / '__0_1.distcp')]
+            distcp_files = [ckpt_dir / '__0_0.distcp']
            for file in distcp_files:
                 assert file.exists()
                 file_size = file.stat().st_size
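
For context on the core change in `write_preloaded_data_multithread`: below is a minimal, runnable sketch of the two-queue thread pattern the PR adopts. The names `write_bucket_stub` and `save_all_buckets` are illustrative stand-ins, not Megatron APIs, and error handling is simplified.

```python
import queue
import threading
from functools import partial
from typing import Dict, List, Union

def write_bucket_stub(idx: int, results_queue: queue.Queue, count_queue: queue.Queue) -> None:
    # Stand-in for FileSystemWriterAsync.write_preloaded_data: perform the write,
    # then report either the results or the caught exception through the queue.
    try:
        result = [f"write-result-{idx}"]  # pretend this wrote bucket `idx` to disk
        results_queue.put((idx, result))
    except Exception as e:
        results_queue.put((idx, e))
    finally:
        # The matching get() + task_done() pair lets count_queue.join() act as a barrier.
        count_queue.get()
        count_queue.task_done()

def save_all_buckets(num_buckets: int) -> Dict[int, Union[List[str], Exception]]:
    results_queue: queue.Queue = queue.Queue()
    count_queue: queue.Queue = queue.Queue()
    threads: List[threading.Thread] = []
    # Threads handle all but the last bucket; the caller writes the last one inline,
    # mirroring the "main worker" branch in write_preloaded_data_multithread.
    for i in range(num_buckets - 1):
        count_queue.put(i)
        threads.append(
            threading.Thread(target=partial(write_bucket_stub, i, results_queue, count_queue))
        )
    for t in threads:
        t.start()
    write_results: Dict[int, Union[List[str], Exception]] = {
        num_buckets - 1: [f"write-result-{num_buckets - 1}"]  # inline "main worker"
    }
    count_queue.join()  # returns once every worker has called task_done()
    for _ in range(num_buckets - 1):
        idx, result_or_exc = results_queue.get()
        write_results[idx] = result_or_exc
    for t in threads:
        t.join()
    return write_results

print(save_all_buckets(4))
```

Keeping the completion barrier (`count_queue.join()`) separate from result delivery means a worker that reports an exception still signals completion, so the collector cannot stall waiting for results that will never arrive; this is the same rationale the original multiprocessing version gave for using two queues.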