Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,4 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str,
"""Each async strategy can be trivially used as a sync strategy."""
async_request = self.async_save(sharded_state_dict, checkpoint_dir)
async_request.execute_sync()
del async_request
93 changes: 60 additions & 33 deletions megatron/core/dist_checkpointing/strategies/filesystem_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
*args,
separation_hint: Optional[str] = None,
use_msc: bool = False,
sequential: bool = False,
**kwargs,
):
self.checkpoint_dir = path
Expand All @@ -100,6 +101,7 @@ def __init__(
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
self.separation_hint = separation_hint
self.sequential = sequential

def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
Expand Down Expand Up @@ -203,7 +205,9 @@ def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Calla
return None, None, []
transform_list = [self.transforms] if hasattr(self, "transforms") else []
return (
partial(self.write_preloaded_data_multiproc, transform_list, self.use_msc),
partial(
self.write_preloaded_data_multiproc, transform_list, self.use_msc, self.sequential
),
partial(self.preload_tensors, self.write_buckets, True),
[torch.distributed.get_rank(), self.write_buckets, self.results_queue],
)
Expand All @@ -222,10 +226,13 @@ def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List

for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
tensor_list = []
for item, tensor in tensor_data:
# We believe these tensors are detached from the model trainers
tensor_list.append((item, tensor.to("cpu", non_blocking=non_blocking)))
# This is required for `PersistentAsyncCaller` to remove reference
del tensor
result.append((file_name, storage_key, (bytes_data, tensor_list)))
if non_blocking:
torch.cuda.synchronize()
return result
Expand All @@ -235,6 +242,7 @@ def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List
def write_preloaded_data_multiproc(
transform_list: List[_StorageWriterTransforms],
use_msc: bool,
sequential: bool,
rank: int,
write_buckets: List[WriteBucket],
global_results_queue: mp.Queue,
Expand Down Expand Up @@ -267,10 +275,18 @@ def write_preloaded_data_multiproc(
local_results_queue = ctx.Queue()
count_queue = ctx.JoinableQueue()
p_list = []

def check_local_output(local_results_or_exc, local_proc_idx):
    """Validate the output produced by one local writer.

    Args:
        local_results_or_exc: what the writer produced - either the list of
            write results or the exception it caught while writing.
        local_proc_idx: index of the local writer, used in the error message.

    Raises:
        RuntimeError: if the writer reported an exception (chained to it).
        AssertionError: if the output is neither an exception nor a list.
    """
    if isinstance(local_results_or_exc, Exception):
        err_msg = (
            f"Local process {local_proc_idx} encountered"
            f" an error: {local_results_or_exc}"
        )
        logger.error(err_msg)
        # Raise explicitly instead of relying on the assert below: the assert
        # would hide the worker's error behind a bare AssertionError and is
        # stripped entirely under `python -O`, which would let the exception
        # object be stored as if it were a valid result.
        raise RuntimeError(err_msg) from local_results_or_exc
    # Internal invariant: a successful writer always returns a list.
    assert isinstance(local_results_or_exc, list), type(local_results_or_exc)

for i, write_bucket in enumerate(write_buckets):
try:
count_queue.put(i)

kwargs = {
"local_proc_idx": i,
"write_bucket": write_bucket,
Expand All @@ -285,20 +301,38 @@ def write_preloaded_data_multiproc(
# Remove the inspect after the test_async_save.py is fixed.
signature = inspect.signature(FileSystemWriterAsync.write_preloaded_data)
if len(signature.parameters) > 6:
kwargs["use_msc"] = use_msc

p_list.append(
ctx.Process(
target=partial(FileSystemWriterAsync.write_preloaded_data, transform_list),
kwargs=kwargs,
kwargs['use_msc'] = use_msc
# Parallel Writers are required
if i < len(write_buckets) - 1 and not sequential:
count_queue.put(i)
p_list.append(
ctx.Process(
target=partial(
FileSystemWriterAsync.write_preloaded_data, transform_list
),
kwargs=kwargs,
)
)
)
else:
kwargs['count_queue'] = None
kwargs['results_queue'] = None
logger.debug('FileSystemWriterAsync: master worker started')
local_output = FileSystemWriterAsync.write_preloaded_data(
transform_list, **kwargs
)
if local_output is not None:
logger.debug(
'FileSystemWriterAsync: master worker results successfully collected'
)
check_local_output(local_output[1], local_output[0])
write_results_or_exc[local_output[0]] = local_output[1]

except Exception as e:
err_msg = f"An error is caught while a proc {i} is created, error: {e}"
logger.error(err_msg)
write_results_or_exc = RuntimeError(err_msg)

if not isinstance(write_results_or_exc, Exception):
if not isinstance(write_results_or_exc, Exception) and len(p_list) > 0 and not sequential:
for p in p_list:
p.start()

Expand All @@ -308,7 +342,7 @@ def write_preloaded_data_multiproc(
count_queue.join()
# At this point, all forked workers completed, so the queue should have exactly
# `len(write_buckets) - 1` items (the last bucket is written in-process above)
for proc_idx in range(len(write_buckets)):
for proc_idx in range(0, len(write_buckets) - 1):
try:
local_proc_idx, local_results_or_exc = local_results_queue.get()
except queue.Empty:
Expand All @@ -318,19 +352,10 @@ def write_preloaded_data_multiproc(
)
break
else:
if isinstance(local_results_or_exc, Exception):
err_msg = (
f"Local process {local_proc_idx} encountered"
f" an error: {local_results_or_exc}"
)
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
check_local_output(local_results_or_exc, local_proc_idx)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()

logger.debug("FileSystemWriterAsync: collected worker results successfully")
logger.debug('FileSystemWriterAsync: collected worker results successfully')

global_results_queue.put(write_results_or_exc)

Expand All @@ -347,7 +372,7 @@ def write_preloaded_data(
count_queue: mp.JoinableQueue,
use_fsync: bool,
**kwargs,
) -> None:
) -> Union[Tuple[int, Exception], None]:
"""
Performs actual data saving to storage.

Expand Down Expand Up @@ -405,17 +430,19 @@ def write_preloaded_data(
except Exception as e:
logger.debug(f"{local_proc_idx} failed")
local_output = (local_proc_idx, e) # type: ignore[assignment]

results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
if results_queue is not None:
results_queue.put(local_output)
if count_queue is not None:
# Signal this process is done.
count_queue.get()
count_queue.task_done()

mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before},"
f" before: {mem_before}, after: {mem_after}"
)
return local_output

def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
"""Write all items from ``plan``."""
Expand Down
9 changes: 8 additions & 1 deletion megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def __init__(
backend: str,
version: int,
keep_only_main_replica: bool = True,
thread_count: int = 2,
thread_count: int = 1,
cached_metadata: bool = False,
separation_hint: Optional[str] = None,
):
Expand Down Expand Up @@ -658,12 +658,19 @@ def async_save(
)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)

sequential = False
if self.separation_hint is not None and self.thread_count <= 1:
self.thread_count = 2
sequential = True

# Use PyT saving mechanism
writer = FileSystemWriterAsync(
checkpoint_dir,
separation_hint=self.separation_hint,
thread_count=self.thread_count,
use_msc=MultiStorageClientFeature.is_enabled(),
sequential=sequential,
)
# This should be set differently if we run in a smaller process group than the default
coordinator = 0
Expand Down
10 changes: 10 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,12 @@ def validate_args(args, defaults={}):
'Disabling --async-save.'
)
args.async_save = False
elif args.dist_ckpt_workers > 1:
warn_rank_0(
'async ckpt forks processes for parallel writing which may introduce '
'instability on checkpoints. Consider using --dist-ckpt-workers=1 in case of '
'issues.'
)

# Inference args
if args.inference_batch_times_seqlen_threshold > -1:
Expand Down Expand Up @@ -2337,6 +2343,10 @@ def _add_checkpointing_args(parser):
group.add_argument('--dist-ckpt-format',
dest='dist_ckpt_format_deprecated',
help='Deprecated: see --ckpt-format.')
group.add_argument('--dist-ckpt-workers', type=int, default=1,
help='Number of workers for distributed checkpointing. '
'Only used for async save. '
'If set to 1, the checkpointing is performed in a single process.')
group.add_argument('--ckpt-fully-parallel-save', action='store_true',
dest='ckpt_fully_parallel_save_deprecated',
help='Deprecated: see --no-ckpt-fully-parallel-save.')
Expand Down
7 changes: 7 additions & 0 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,13 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
if args.async_save:
save_strategy.thread_count = args.dist_ckpt_workers
else:
# We don't allow per-rank parallel save for sync save
logger.warning('Per-rank parallel save is not supported for sync save. '
'Setting args.dist_ckpt_workers to 1')
save_strategy.thread_count = 1
if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
if cached_global_metadata is not None:
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/dist_checkpointing/test_torch_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils

Expand Down Expand Up @@ -117,7 +116,7 @@ def test_cpu_tensors_dont_take_too_much_space(self, tmp_path_dist_ckpt):
) as ckpt_dir:
save(sharded_state_dict, ckpt_dir)

distcp_files = [(ckpt_dir / '__0_0.distcp'), (ckpt_dir / '__0_1.distcp')]
distcp_files = [(ckpt_dir / '__0_0.distcp')]
for file in distcp_files:
assert file.exists()
file_size = file.stat().st_size
Expand Down
Loading