Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,4 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str,
"""Each async strategy can be trivially used as a sync strategy."""
async_request = self.async_save(sharded_state_dict, checkpoint_dir)
async_request.execute_sync()
del async_request
151 changes: 80 additions & 71 deletions megatron/core/dist_checkpointing/strategies/filesystem_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import pickle
import queue
import threading
from functools import partial
from heapq import heappop, heappush
from itertools import chain
Expand Down Expand Up @@ -67,7 +68,7 @@ class FileSystemWriterAsync(FileSystemWriter):
1. Call `write_data`
2. Externally start an async process with `get_save_function_and_args` and its arguments.
3. The async function `writer_proxy_func` calls `write_preloaded_data` across multiple
processes.
threads (no child processes).
4. Once saving is finalized on all ranks, call `super().finish` with the results stored
in `self.writer_result`.

Expand Down Expand Up @@ -150,7 +151,7 @@ def _clone_if_needed(ten: torch.Tensor):
is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
return ten.clone() if is_view else ten

# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer thread
self.write_buckets = []
for group_name, group_buckets in _split_by_separation_hint(
item_buckets, self.separation_hint
Expand Down Expand Up @@ -203,7 +204,7 @@ def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Calla
return None, None, []
transform_list = [self.transforms] if hasattr(self, "transforms") else []
return (
partial(self.write_preloaded_data_multiproc, transform_list, self.use_msc),
partial(self.write_preloaded_data_multithread, transform_list, self.use_msc),
partial(self.preload_tensors, self.write_buckets, True),
[torch.distributed.get_rank(), self.write_buckets, self.results_queue],
)
Expand All @@ -222,55 +223,61 @@ def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List

for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
tensor_list = []
for item, tensor in tensor_data:
# we believe these tensors are detached from the model trainers
tensor_list.append((item, tensor.to("cpu", non_blocking=non_blocking)))
# This is required for `PersistentAsyncCaller` to remove reference
del tensor
result.append((file_name, storage_key, (bytes_data, tensor_list)))
if non_blocking:
torch.cuda.synchronize()
return result

@staticmethod
@_disable_gc()
def write_preloaded_data_multiproc(
def write_preloaded_data_multithread(
transform_list: List[_StorageWriterTransforms],
use_msc: bool,
rank: int,
write_buckets: List[WriteBucket],
global_results_queue: mp.Queue,
) -> None:
"""
Performs saving data to storage with multiple processes.
Performs saving data to storage with multiple threads.

Starts predefined number of processes and uses 2 queues to make sure the results
are complete:
- local_results_queue - to send the actual results
- count_queue - small queue to mark worker as completed
Uses threads (not processes) so that this can run safely inside a daemon process
without spawning child processes. Uses two queues:
- local_results_queue - to collect write results from worker threads
- count_queue - to signal worker completion (task_done/join).

Using just one queue disallowed proper exception handling.

This method is meant to be run in a forked subprocess.
Triggering GC during execution leads to CUDA errors
(cleaning up tensors owned by the parent process).
Triggering GC during execution can lead to CUDA errors when tensors are shared.
To prevent this, we disable the GC explicitly for this function with _disable_gc.

Args:
write_buckets (List[WriteBucket]): write plan
global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
(or an Exception) from parallel write processes to the main training process
global_results_queue (mp.Queue): queue to send Dict[List[WriteResults]]
(or an Exception) back to the main training process
Returns: None
"""
logger = logging.getLogger(__name__)
w_start = time()
write_results_or_exc: Union[dict, Exception] = dict()
ctx = mp.get_context("fork")
local_results_queue = ctx.Queue()
count_queue = ctx.JoinableQueue()
p_list = []
local_results_queue: queue.Queue = queue.Queue()
count_queue: queue.Queue = queue.Queue()
thread_list: List[threading.Thread] = []

def check_local_output(local_results_or_exc, local_worker_idx):
if isinstance(local_results_or_exc, Exception):
err_msg = (
f"Local worker {local_worker_idx} encountered"
f" an error: {local_results_or_exc}"
)
logger.error(err_msg)
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)

for i, write_bucket in enumerate(write_buckets):
try:
count_queue.put(i)

kwargs = {
"local_proc_idx": i,
"write_bucket": write_bucket,
Expand All @@ -280,86 +287,86 @@ def write_preloaded_data_multiproc(
}

if use_msc:
import inspect

# Remove the inspect after the test_async_save.py is fixed.
signature = inspect.signature(FileSystemWriterAsync.write_preloaded_data)
if len(signature.parameters) > 6:
kwargs["use_msc"] = use_msc

p_list.append(
ctx.Process(
kwargs['use_msc'] = use_msc
# Parallel writers: spawn threads for all but the last bucket
if i < len(write_buckets) - 1:
count_queue.put(i)
t = threading.Thread(
target=partial(FileSystemWriterAsync.write_preloaded_data, transform_list),
kwargs=kwargs,
)
)
thread_list.append(t)
else:
kwargs['count_queue'] = None
kwargs['results_queue'] = None
logger.debug('FileSystemWriterAsync: main worker started')
local_output = FileSystemWriterAsync.write_preloaded_data(
transform_list, **kwargs
)
if local_output is not None:
logger.debug(
'FileSystemWriterAsync: main worker results successfully collected'
)
check_local_output(local_output[1], local_output[0])
write_results_or_exc[local_output[0]] = local_output[1]

except Exception as e:
err_msg = f"An error is caught while a proc {i} is created, error: {e}"
err_msg = f"An error is caught while starting worker {i}, error: {e}"
logger.error(err_msg)
write_results_or_exc = RuntimeError(err_msg)

if not isinstance(write_results_or_exc, Exception):
for p in p_list:
p.start()
if not isinstance(write_results_or_exc, Exception) and len(thread_list) > 0:
for t in thread_list:
t.start()

logger.debug("FileSystemWriterAsync: collecting worker results...")

# To make sure all nodes are completed
count_queue.join()
# At this point, all workers completed, so the queue should have exactly
# `len(write_buckets)` items
for proc_idx in range(len(write_buckets)):
for _ in range(len(write_buckets) - 1):
try:
local_proc_idx, local_results_or_exc = local_results_queue.get()
except queue.Empty:
write_results_or_exc = RuntimeError(
"Unexpected empty `local_results_queue`"
f" (got only {proc_idx}/{len(write_buckets)} items)"
f" (expected {len(write_buckets) - 1} items)"
)
break
else:
if isinstance(local_results_or_exc, Exception):
err_msg = (
f"Local process {local_proc_idx} encountered"
f" an error: {local_results_or_exc}"
)
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
check_local_output(local_results_or_exc, local_proc_idx)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()

logger.debug("FileSystemWriterAsync: collected worker results successfully")
for t in thread_list:
t.join()
logger.debug('FileSystemWriterAsync: collected worker results successfully')

global_results_queue.put(write_results_or_exc)

w_end = time()
logger.debug(f"{w_end}, rank: {rank}, write(sync,parallel): {w_end - w_start}")
logger.debug(f"{w_end}, rank: {rank}, write(sync,threads): {w_end - w_start}")

@staticmethod
@_disable_gc()
def write_preloaded_data(
transform_list: List[_StorageWriterTransforms],
local_proc_idx: int,
write_bucket: WriteBucket,
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
results_queue: Optional[queue.Queue],
count_queue: Optional[queue.Queue],
use_fsync: bool,
**kwargs,
) -> None:
) -> Union[Tuple[int, Exception], None]:
"""
Performs actual data saving to storage.
Performs actual data saving to storage (used by worker threads).

Args:
local_proc_idx (int): index of a local process that performs writing
local_proc_idx (int): index of the worker that performs writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.Queue): queue to return the write results
to the proxy checkpoint process.
count_queue (mp.JoinableQueue): queue to marks worker task as completed
results_queue (queue.Queue): queue to return the write results
count_queue (queue.Queue): queue to signal worker task completion (get + task_done)
use_fsync (bool): if True, calls os.fsync at the end of saving

Returns: None, the write result are put into the `queue`
Returns: None when running in a worker thread (results are put into the queue); a result tuple when running as the main worker
"""
logger = logging.getLogger(__name__)
logger.debug(f"{local_proc_idx} started")
Expand Down Expand Up @@ -405,17 +412,19 @@ def write_preloaded_data(
except Exception as e:
logger.debug(f"{local_proc_idx} failed")
local_output = (local_proc_idx, e) # type: ignore[assignment]

results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
if results_queue is not None:
results_queue.put(local_output)
if count_queue is not None:
# Signal this process is done.
count_queue.get()
count_queue.task_done()

mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before},"
f" before: {mem_before}, after: {mem_after}"
)
return local_output

def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
"""Write all items from ``plan``."""
Expand All @@ -427,7 +436,7 @@ def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]:
into a single results lists. Includes error check.

Returns (Union(List[WriteResult], WRAPPED_EXCEPTION): the list of write results
from all local processes performing the save, or a WRAPPED_EXCEPTION if
from all local workers (threads) performing the save, or a WRAPPED_EXCEPTION if
an exception was raised during the writing process.
"""
assert self.write_buckets is not None
Expand Down
6 changes: 5 additions & 1 deletion megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def __init__(
backend: str,
version: int,
keep_only_main_replica: bool = True,
thread_count: int = 2,
thread_count: int = 1,
cached_metadata: bool = False,
separation_hint: Optional[str] = None,
):
Expand Down Expand Up @@ -658,6 +658,10 @@ def async_save(
)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)

if self.separation_hint is not None and self.thread_count <= 1:
self.thread_count = 2

# Use PyT saving mechanism
writer = FileSystemWriterAsync(
checkpoint_dir,
Expand Down
6 changes: 5 additions & 1 deletion megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ def validate_args(args, defaults={}):
'Disabling --async-save.'
)
args.async_save = False

# Inference args
if args.inference_batch_times_seqlen_threshold > -1:
assert args.pipeline_model_parallel_size > 1, \
Expand Down Expand Up @@ -2448,6 +2448,10 @@ def _add_checkpointing_args(parser):
group.add_argument('--dist-ckpt-format',
dest='dist_ckpt_format_deprecated',
help='Deprecated: see --ckpt-format.')
group.add_argument('--dist-ckpt-workers', type=int, default=1,
help='Number of workers for distributed checkpointing. '
'Only used for async save. '
'If set to 1, the checkpointing is performed in a single process.')
group.add_argument('--ckpt-fully-parallel-save', action='store_true',
dest='ckpt_fully_parallel_save_deprecated',
help='Deprecated: see --no-ckpt-fully-parallel-save.')
Expand Down
7 changes: 7 additions & 0 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,13 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
if args.async_save:
save_strategy.thread_count = args.dist_ckpt_workers
else:
# We don't allow per-rank parallel save for sync save
logger.warning('Per-rank parallel save is not supported for sync save. '
'Setting args.dist_ckpt_workers to 1')
save_strategy.thread_count = 1
if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
if cached_global_metadata is not None:
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/dist_checkpointing/test_torch_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils

Expand Down Expand Up @@ -117,7 +116,7 @@ def test_cpu_tensors_dont_take_too_much_space(self, tmp_path_dist_ckpt):
) as ckpt_dir:
save(sharded_state_dict, ckpt_dir)

distcp_files = [(ckpt_dir / '__0_0.distcp'), (ckpt_dir / '__0_1.distcp')]
distcp_files = [(ckpt_dir / '__0_0.distcp')]
for file in distcp_files:
assert file.exists()
file_size = file.stat().st_size
Expand Down
Loading