From 31da86de1eaa8fad028cd452dfde2894c5704800 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 22 Nov 2023 11:05:37 -0500 Subject: [PATCH 01/13] Add unsharder parser --- scripts/storage_cleaner.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 582d6b5ed..ce1daebbf 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -28,11 +28,13 @@ log = logging.getLogger(__name__) -DEFAULT_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB +DEFAULT_DELETE_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB +UNSHARD_SCRIPT_PATH: str = "scripts/unshard.py" class CleaningOperations(Enum): DELETE_BAD_RUNS = auto() + UNSHARD_CHECKPOINTS = auto() class StorageType(Enum): @@ -846,11 +848,35 @@ def _add_delete_subparser(subparsers: _SubParsersAction): delete_runs_parser.add_argument( "--max_archive_size", - default=DEFAULT_MAX_ARCHIVE_SIZE, + default=DEFAULT_DELETE_MAX_ARCHIVE_SIZE, help="Max size archive files to consider for deletion (in bytes). Any archive larger than this is ignored/not deleted.", ) +def _add_unsharding_subparser(subparsers: _SubParsersAction): + unsharding_runs_parser: ArgumentParser = subparsers.add_parser("unshard", help="unshard checkpoints of a run") + unsharding_runs_parser.set_defaults(op=CleaningOperations.UNSHARD_CHECKPOINTS) + + unsharding_runs_parser.add_argument( + "run_path", + help="Path to run directory or archive containing checkpoints to unshard.", + ) + unsharding_runs_parser.add_argument( + "dest_dir", + help="Path to directory where the run's unsharded checkpoints should be output (only the unsharded checkpoints are stored).", + ) + unsharding_runs_parser.add_argument( + "--latest_checkpoint_only", + action="store_true", + help="If set, only the latest checkpoint of each run (if sharded) is unsharded.", + ) + unsharding_runs_parser.add_argument( + "--script_path", + default=UNSHARD_SCRIPT_PATH, + help=f"Path of the unsharder script. 
Set to `{UNSHARD_SCRIPT_PATH}` by default.", + ) + + def get_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -862,6 +888,7 @@ def get_parser() -> ArgumentParser: subparsers = parser.add_subparsers(dest="command", help="Cleaning commands", required=True) _add_delete_subparser(subparsers) + _add_unsharding_subparser(subparsers) return parser From 182430a0a4047760d8735f3cd1ac1373d60ec95f Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 22 Nov 2023 11:07:23 -0500 Subject: [PATCH 02/13] Add and invoke unsharding method skeleton --- scripts/storage_cleaner.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index ce1daebbf..304a5f337 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -28,6 +28,7 @@ log = logging.getLogger(__name__) +CONFIG_YAML: str = "config.yaml" DEFAULT_DELETE_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB UNSHARD_SCRIPT_PATH: str = "scripts/unshard.py" @@ -646,6 +647,13 @@ class DeleteBadRunsConfig: max_archive_size: Optional[int] +@dataclass +class UnshardCheckpointsConfig: + dry_run: bool + unshard_script_path: Path + latest_checkpoint_only: bool + + def _get_storage_adapter_for_path(path: str) -> StorageAdapter: storage_type = StorageAdapter.get_storage_type_for_path(path) return StorageAdapter.create_storage_adapter(storage_type) @@ -764,6 +772,10 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig): _delete_if_bad_run(storage, run_path, config) +def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig): + raise NotImplementedError() + + def perform_operation(args: argparse.Namespace): if args.dry_run: log.info("Dry run, no irreversible actions will be taken") @@ -779,6 +791,16 @@ def perform_operation(args: argparse.Namespace): delete_bad_runs(args.run_paths, delete_bad_runs_config) else: raise ValueError("Run paths not provided for run cleaning") + elif args.op == CleaningOperations.UNSHARD_CHECKPOINTS: + unshard_checkpoints_config = UnshardCheckpointsConfig( + dry_run=args.dry_run, + unshard_script_path=args.script_path, + latest_checkpoint_only=args.latest_checkpoint_only, + ) + if args.run_path is not None: + unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config) + else: + raise ValueError("Run path not provided for unsharding") else: raise NotImplementedError(args.op) From 6067be4a8b3310c2bcc087eedbec6370804e5fd1 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 22 Nov 2023 11:09:12 -0500 Subject: [PATCH 03/13] Add core unsharding logic --- scripts/storage_cleaner.py | 120 ++++++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 304a5f337..b8863383d 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -3,6 +3,7 @@ import os import re import shutil +import subprocess import tempfile from abc import ABC, abstractmethod from argparse import ArgumentParser, _SubParsersAction @@ -772,8 +773,125 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig): _delete_if_bad_run(storage, run_path, config) +def _is_sharded_checkpoint_dir(directory: str) -> bool: + storage = _get_storage_adapter_for_path(directory) + return storage.is_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None + + +def _get_checkpoint_number(checkpoint_dir: str) -> int: + checkpoint_dir_name = 
Path(checkpoint_dir).name + checkpoint_dir_name = checkpoint_dir_name.removesuffix("-unsharded") + match = re.match(r"step(\d+)$", checkpoint_dir_name) + if match is None: + raise ValueError(f"Failed to find checkpoint number for dir {checkpoint_dir}") + + return int(match.group(1)) + + +def _get_sharded_checkpoint_dirs( + storage: StorageAdapter, run_dir_or_archive: str, latest_checkpoint_only: bool +) -> List[str]: + run_subdirectories = _get_run_entries(run_dir_or_archive, storage, full_path=True) + sharded_checkpoint_directories = list( + filter(_is_sharded_checkpoint_dir, run_subdirectories) + ) + + if latest_checkpoint_only: + latest_checkpoint_directory = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number) + sharded_checkpoint_directories = ( + [latest_checkpoint_directory] if latest_checkpoint_directory is not None else [] + ) + + log.info("Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive) + + return sharded_checkpoint_directories + + +def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, run_dir_or_archive: str, unsharding_config: UnshardCheckpointsConfig): + local_storage = LocalFileSystemAdapter() + + # Download checkpoint to a temp dir + sharding_input_dir = local_storage.create_temp_dir() + src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir) + src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir) + + # Set unsharder output to a temp dir + sharding_output_dir: str + sharding_output_dir = local_storage.create_temp_dir() + + result = subprocess.run( + ["python", str(unsharding_config.unshard_script_path), sharding_input_dir, sharding_output_dir], + check=False, + ) + if result.returncode != 0: + log.error( + "Unsharding from %s to %s failed with error code %d", + sharding_input_dir, + sharding_output_dir, + result.returncode, + ) + + local_storage.delete_path(sharding_output_dir) + return + + log.info( + "Successfully unsharded from %s to %s, starting upload to %s", + sharding_input_dir, + sharding_output_dir, + dest_dir, + ) + + dest_storage = _get_storage_adapter_for_path(dest_dir) + dest_storage.upload(sharding_output_dir, dest_dir) + + +def _unshard_checkpoints( + run_storage: StorageAdapter, + run_dir_or_archive: str, + checkpoints_dest_dir: str, + config: UnshardCheckpointsConfig, +): + log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive) + + sharded_checkpoint_directories = _get_sharded_checkpoint_dirs( + run_storage, run_dir_or_archive, config.latest_checkpoint_only + ) + for sharded_checkpoint_directory in sharded_checkpoint_directories: + sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name + + if run_storage.is_dir(run_dir_or_archive): + unsharded_checkpoint_directory_in_source = os.path.join( + run_dir_or_archive, f"{sharded_checkpoint_dir_name}-unsharded" + ) + if run_storage.is_dir(unsharded_checkpoint_directory_in_source): + log.info( + "Unsharded directory already exists for %s at source %s, skipping", + sharded_checkpoint_dir_name, + unsharded_checkpoint_directory_in_source, + ) + continue + + dest_directory = os.path.join(checkpoints_dest_dir, f"{sharded_checkpoint_dir_name}-unsharded") + dest_storage = _get_storage_adapter_for_path(dest_directory) + if dest_storage.is_dir(dest_directory): + log.info( + "Unsharded directory already exists for %s at destination %s, skipping", + sharded_checkpoint_dir_name, + dest_directory, + ) + continue + + if config.dry_run: + 
log.info("Would unshard sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) + else: + log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) + _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir_or_archive, config) + + def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig): - raise NotImplementedError() + storage = _get_storage_adapter_for_path(run_path) + run_dir_or_archive = _format_dir_or_archive_path(storage, run_path) + _unshard_checkpoints(storage, run_dir_or_archive, checkpoints_dest_dir, config) def perform_operation(args: argparse.Namespace): From 79d9a2fc00a7cd238b3a80ece87dce3e8acedf9d Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:34:36 -0800 Subject: [PATCH 04/13] Remove unused argument --- scripts/storage_cleaner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 8c8cfd569..7ea9e2cd8 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -752,7 +752,7 @@ def _get_sharded_checkpoint_dirs( return sharded_checkpoint_directories -def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, run_dir_or_archive: str, unsharding_config: UnshardCheckpointsConfig): +def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, unsharding_config: UnshardCheckpointsConfig): local_storage = LocalFileSystemAdapter() # Download checkpoint to a temp dir @@ -830,7 +830,7 @@ def _unshard_checkpoints( log.info("Would unshard sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) else: log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) - _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir_or_archive, config) + _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, config) def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig): From 7538690acb7dad9063d982f626f226ebfbe5d1a9 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:36:34 -0800 Subject: [PATCH 05/13] Make unsharding handle unarchiving --- scripts/storage_cleaner.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 7ea9e2cd8..d56d9e850 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -1,5 +1,6 @@ import argparse import logging +import os import re import shutil import subprocess @@ -798,23 +799,25 @@ def _unshard_checkpoints( ): log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive) + run_dir = _unarchive_if_archive(run_dir_or_archive, run_storage) + run_dir_storage = _get_storage_adapter_for_path(run_dir) + sharded_checkpoint_directories = _get_sharded_checkpoint_dirs( - run_storage, run_dir_or_archive, config.latest_checkpoint_only + run_dir_storage, run_dir, config.latest_checkpoint_only ) for sharded_checkpoint_directory in sharded_checkpoint_directories: sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name - if run_storage.is_dir(run_dir_or_archive): - unsharded_checkpoint_directory_in_source = os.path.join( - run_dir_or_archive, f"{sharded_checkpoint_dir_name}-unsharded" + unsharded_checkpoint_directory_in_source = os.path.join( + run_dir, f"{sharded_checkpoint_dir_name}-unsharded" + ) + if 
run_dir_storage.is_dir(unsharded_checkpoint_directory_in_source): + log.info( + "Unsharded directory already exists for %s at source %s, skipping", + sharded_checkpoint_dir_name, + unsharded_checkpoint_directory_in_source, ) - if run_storage.is_dir(unsharded_checkpoint_directory_in_source): - log.info( - "Unsharded directory already exists for %s at source %s, skipping", - sharded_checkpoint_dir_name, - unsharded_checkpoint_directory_in_source, - ) - continue + continue dest_directory = os.path.join(checkpoints_dest_dir, f"{sharded_checkpoint_dir_name}-unsharded") dest_storage = _get_storage_adapter_for_path(dest_directory) From 62e3dfe986e6124bb1f17d8ee77715bbc887beeb Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:37:28 -0800 Subject: [PATCH 06/13] Remove outdated reference to _get_run_entries --- scripts/storage_cleaner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index d56d9e850..d683e0546 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -735,9 +735,10 @@ def _get_checkpoint_number(checkpoint_dir: str) -> int: def _get_sharded_checkpoint_dirs( - storage: StorageAdapter, run_dir_or_archive: str, latest_checkpoint_only: bool + run_dir_storage: StorageAdapter, run_dir: str, run_dir_or_archive: str, latest_checkpoint_only: bool ) -> List[str]: - run_subdirectories = _get_run_entries(run_dir_or_archive, storage, full_path=True) + run_subdir_names = run_dir_storage.list_dirs(run_dir) + run_subdirectories = list(map(lambda dir_name: os.path.join(run_dir, dir_name), run_subdir_names)) sharded_checkpoint_directories = list( filter(_is_sharded_checkpoint_dir, run_subdirectories) ) @@ -803,7 +804,7 @@ def _unshard_checkpoints( run_dir_storage = _get_storage_adapter_for_path(run_dir) sharded_checkpoint_directories = _get_sharded_checkpoint_dirs( - run_dir_storage, run_dir, config.latest_checkpoint_only + run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only ) for sharded_checkpoint_directory in sharded_checkpoint_directories: sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name From 5c831c8e92e0b23884d1b47a2195ef87b599dcf7 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:50:55 -0800 Subject: [PATCH 07/13] Call unsharder directly rather than as a subprocess --- scripts/storage_cleaner.py | 18 ++++++++---------- scripts/unshard.py | 10 +++++++--- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index d683e0546..9a09d23fd 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -3,7 +3,6 @@ import os import re import shutil -import subprocess import tempfile from abc import ABC, abstractmethod from argparse import ArgumentParser, _SubParsersAction @@ -24,6 +23,7 @@ from olmo import util from olmo.aliases import PathOrStr +from scripts.unshard import unshard log = logging.getLogger(__name__) @@ -754,7 +754,7 @@ def _get_sharded_checkpoint_dirs( return sharded_checkpoint_directories -def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, unsharding_config: UnshardCheckpointsConfig): +def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str): local_storage = LocalFileSystemAdapter() # Download checkpoint to a temp dir @@ -766,16 +766,14 @@ def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, unsharding_c sharding_output_dir: str sharding_output_dir = local_storage.create_temp_dir() - result = 
subprocess.run( - ["python", str(unsharding_config.unshard_script_path), sharding_input_dir, sharding_output_dir], - check=False, - ) - if result.returncode != 0: + try: + unshard(sharding_input_dir, sharding_output_dir) + except RuntimeError as e: log.error( - "Unsharding from %s to %s failed with error code %d", + "Unsharding from %s to %s failed with exception: %s", sharding_input_dir, sharding_output_dir, - result.returncode, + e, ) local_storage.delete_path(sharding_output_dir) @@ -834,7 +832,7 @@ def _unshard_checkpoints( log.info("Would unshard sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) else: log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory) - _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, config) + _unshard_checkpoint(sharded_checkpoint_directory, dest_directory) def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig): diff --git a/scripts/unshard.py b/scripts/unshard.py index 8e285c53a..e9f4071d2 100644 --- a/scripts/unshard.py +++ b/scripts/unshard.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def main( +def unshard( input_dir: Union[str, Path], output_dir: Union[str, Path], sharded_checkpoint_type: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy, @@ -65,7 +65,7 @@ def main( shutil.copy(input_dir / "config.yaml", output_dir) -if __name__ == "__main__": +def main(): import argparse parser = argparse.ArgumentParser(prog="unshard.py", description="Unshard sharded checkpoints on CPU") @@ -84,9 +84,13 @@ def main( args = parser.parse_args() logging.basicConfig(level=logging.INFO) - main( + unshard( args.input_dir, args.output_dir, sharded_checkpoint_type=args.type, model_only=args.model_only, ) + + +if __name__ == "__main__": + main() From 6d11a65806c49d28b833e3f7012ba8cf25287842 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:51:53 -0800 Subject: [PATCH 08/13] Clean up temp unsharding dirs after use --- scripts/storage_cleaner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 9a09d23fd..1867229d7 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -776,6 +776,7 @@ def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str): e, ) + local_storage.delete_path(sharding_input_dir) local_storage.delete_path(sharding_output_dir) return @@ -789,6 +790,9 @@ def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str): dest_storage = _get_storage_adapter_for_path(dest_dir) dest_storage.upload(sharding_output_dir, dest_dir) + local_storage.delete_path(sharding_input_dir) + local_storage.delete_path(sharding_output_dir) + def _unshard_checkpoints( run_storage: StorageAdapter, From 5a86649485985876426baa8180df9f327489076f Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 10:56:49 -0800 Subject: [PATCH 09/13] Run ruff --- scripts/storage_cleaner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 610853522..489c2cc42 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -737,9 +737,7 @@ def _get_sharded_checkpoint_dirs( ) -> List[str]: run_subdir_names = run_dir_storage.list_dirs(run_dir) run_subdirectories = list(map(lambda dir_name: os.path.join(run_dir, dir_name), run_subdir_names)) - sharded_checkpoint_directories = list( - filter(_is_sharded_checkpoint_dir, run_subdirectories) - ) 
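+    # Keep only subdirectories named like "step<N>", i.e. sharded checkpoint dirs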
+ sharded_checkpoint_directories = list(filter(_is_sharded_checkpoint_dir, run_subdirectories)) if latest_checkpoint_only: latest_checkpoint_directory = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number) @@ -747,7 +745,9 @@ def _get_sharded_checkpoint_dirs( [latest_checkpoint_directory] if latest_checkpoint_directory is not None else [] ) - log.info("Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive) + log.info( + "Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive + ) return sharded_checkpoint_directories From 75884f238b78fea343cce99d10a44f8a78d82d43 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 11:18:26 -0800 Subject: [PATCH 10/13] Add unsharding script logic directly into storage cleaner --- scripts/storage_cleaner.py | 39 ++++++++++++++++++++++++++++++++++++-- scripts/unshard.py | 10 +++------- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 489c2cc42..219c2ec6c 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -16,6 +16,7 @@ import boto3.session import botocore.exceptions as boto_exceptions import google.cloud.storage as gcs +import torch from cached_path import add_scheme_client, cached_path, set_cache_dir from cached_path.schemes import S3Client from google.api_core.exceptions import NotFound @@ -23,7 +24,8 @@ from olmo import util from olmo.aliases import PathOrStr -from scripts.unshard import unshard +from olmo.checkpoint import Checkpointer, LocalShardedCheckpointer, TorchLegacyShardedCheckpointer +from olmo.config import ShardedCheckpointerType, TrainConfig log = logging.getLogger(__name__) @@ -765,7 +767,19 @@ def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str): sharding_output_dir = local_storage.create_temp_dir() try: - unshard(sharding_input_dir, sharding_output_dir) + config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False) + sharded_checkpoint_type = config.sharded_checkpointer + checkpointer: Checkpointer + if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy: + checkpointer = TorchLegacyShardedCheckpointer(config) + elif sharded_checkpoint_type == ShardedCheckpointerType.local: + checkpointer = LocalShardedCheckpointer(config) + else: + raise NotImplementedError(sharded_checkpoint_type) + + model_state_dict, optim_state_dict, trainer_state_dict = checkpointer.unshard_checkpoint( + sharding_input_dir + ) except RuntimeError as e: log.error( "Unsharding from %s to %s failed with exception: %s", @@ -778,6 +792,27 @@ def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str): local_storage.delete_path(sharding_output_dir) return + # model + model_output = str(Path(sharding_output_dir) / "model.pt") + log.info("Saving model state to %s", model_output) + torch.save(model_state_dict, model_output) + del model_state_dict + + # optimizer + optim_output = str(Path(sharding_output_dir) / "optim.pt") + log.info("Saving optimizer state to %s", optim_output) + torch.save(optim_state_dict, optim_output) + del optim_state_dict + + # trainer + train_output = str(Path(sharding_output_dir) / "train.pt") + log.info("Saving everything else to %s", train_output) + torch.save(trainer_state_dict, train_output) + del trainer_state_dict + + log.info("Copying config.yaml to %s", sharding_output_dir) + shutil.copy(Path(sharding_input_dir) / "config.yaml", sharding_output_dir) + 
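+    # With model.pt, optim.pt, train.pt and config.yaml in place, the output dir is a complete unsharded checkpoint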
log.info( "Successfully unsharded from %s to %s, starting upload to %s", sharding_input_dir, diff --git a/scripts/unshard.py b/scripts/unshard.py index e9f4071d2..8e285c53a 100644 --- a/scripts/unshard.py +++ b/scripts/unshard.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def unshard( +def main( input_dir: Union[str, Path], output_dir: Union[str, Path], sharded_checkpoint_type: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy, @@ -65,7 +65,7 @@ def unshard( shutil.copy(input_dir / "config.yaml", output_dir) -def main(): +if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(prog="unshard.py", description="Unshard sharded checkpoints on CPU") @@ -84,13 +84,9 @@ def main(): args = parser.parse_args() logging.basicConfig(level=logging.INFO) - unshard( + main( args.input_dir, args.output_dir, sharded_checkpoint_type=args.type, model_only=args.model_only, ) - - -if __name__ == "__main__": - main() From e6483ae69faf4516603df271c56b9797586778b6 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 11:19:01 -0800 Subject: [PATCH 11/13] Remove redundant top level dirs from archives --- scripts/storage_cleaner.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 219c2ec6c..92a2bdd7a 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -662,12 +662,21 @@ def _unarchive_if_archive(dir_or_archive: str, storage: StorageAdapter) -> str: unarchived_dir = cached_path(dir_or_archive, extract_archive=True) assert unarchived_dir != Path(dir_or_archive) + # The unarchived file could have a redundant top-level directory. If the top-level + # directory has only a directory, we should return that directory instead. 
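+    # (e.g. archiving a run folder can yield <tmp>/<run-name>/<checkpoints...> once extracted)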
+ unarchived_dir_storage = _get_storage_adapter_for_path(str(unarchived_dir)) + unarchived_dir_entries = unarchived_dir_storage.list_entries(str(unarchived_dir)) + if len(unarchived_dir_entries) == 1: + unarchived_entry_path = unarchived_dir / unarchived_dir_entries[0] + if unarchived_dir_storage.is_dir(str(unarchived_entry_path)): + return str(unarchived_entry_path) + return str(unarchived_dir) if storage.is_dir(dir_or_archive): return dir_or_archive - raise ValueError(f"Run dir or archive {dir_or_archive} is not a valid archive file or directory") + raise ValueError(f"Dir or archive {dir_or_archive} is not a valid archive file or directory") def _should_delete_run(storage: StorageAdapter, run_dir_or_archive: str, config: DeleteBadRunsConfig) -> bool: From 0b0e42ce2918b033536c28c3e8c2602b8bf57a64 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 11:21:35 -0800 Subject: [PATCH 12/13] Remove unused script path option --- scripts/storage_cleaner.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 92a2bdd7a..682cc5ffb 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -32,7 +32,6 @@ CONFIG_YAML: str = "config.yaml" DEFAULT_DELETE_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB -UNSHARD_SCRIPT_PATH: str = "scripts/unshard.py" class CleaningOperations(Enum): @@ -604,7 +603,6 @@ class DeleteBadRunsConfig: @dataclass class UnshardCheckpointsConfig: dry_run: bool - unshard_script_path: Path latest_checkpoint_only: bool @@ -905,7 +903,6 @@ def perform_operation(args: argparse.Namespace): elif args.op == CleaningOperations.UNSHARD_CHECKPOINTS: unshard_checkpoints_config = UnshardCheckpointsConfig( dry_run=args.dry_run, - unshard_script_path=args.script_path, latest_checkpoint_only=args.latest_checkpoint_only, ) if args.run_path is not None: @@ -999,11 +996,6 @@ def _add_unsharding_subparser(subparsers: _SubParsersAction): action="store_true", help="If set, only the latest checkpoint of each run (if sharded) is unsharded.", ) - unsharding_runs_parser.add_argument( - "--script_path", - default=UNSHARD_SCRIPT_PATH, - help=f"Path of the unsharder script. Set to `{UNSHARD_SCRIPT_PATH}` by default.", - ) def get_parser() -> ArgumentParser: From 6168c87f107a74f9270e270df0f5246545ef8029 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Dec 2023 11:28:26 -0800 Subject: [PATCH 13/13] Run isort --- scripts/storage_cleaner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 682cc5ffb..aa8184605 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -24,7 +24,11 @@ from olmo import util from olmo.aliases import PathOrStr -from olmo.checkpoint import Checkpointer, LocalShardedCheckpointer, TorchLegacyShardedCheckpointer +from olmo.checkpoint import ( + Checkpointer, + LocalShardedCheckpointer, + TorchLegacyShardedCheckpointer, +) from olmo.config import ShardedCheckpointerType, TrainConfig log = logging.getLogger(__name__)
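---

Usage sketch for the new subcommand (illustrative only, not part of the patch
series; the paths are hypothetical, and the top-level dry-run flag is assumed
to be spelled --dry_run, matching the args.dry_run that perform_operation reads
from the elided top-level argument):

    # Unshard every sharded step<N> checkpoint of a run; each one is written
    # to <dest_dir>/step<N>-unsharded, skipping checkpoints that already have
    # an unsharded copy at the source or destination.
    python scripts/storage_cleaner.py unshard s3://my-bucket/runs/my-run s3://my-bucket/unsharded/my-run

    # Unshard only the checkpoint with the highest step number, previewing
    # the work without downloading or writing anything.
    python scripts/storage_cleaner.py --dry_run unshard s3://my-bucket/runs/my-run s3://my-bucket/unsharded/my-run --latest_checkpoint_only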