Skip to content

Commit

Permalink
Merge pull request #381 from allenai/shanea/storage-cleaner-unsharding-2
Browse files Browse the repository at this point in the history
[Storage Cleaner] Add unsharding to storage cleaner
  • Loading branch information
2015aroras authored Dec 11, 2023
2 parents 9cc7154 + 6c6ede1 commit 1ede949
Showing 1 changed file with 216 additions and 3 deletions.
219 changes: 216 additions & 3 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import os
import re
import shutil
import tempfile
Expand All @@ -15,22 +16,31 @@
import boto3.session
import botocore.exceptions as boto_exceptions
import google.cloud.storage as gcs
import torch
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from rich.progress import Progress, TaskID, track

from olmo import util
from olmo.aliases import PathOrStr
from olmo.checkpoint import (
Checkpointer,
LocalShardedCheckpointer,
TorchLegacyShardedCheckpointer,
)
from olmo.config import ShardedCheckpointerType, TrainConfig

log = logging.getLogger(__name__)


DEFAULT_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB
CONFIG_YAML: str = "config.yaml"
DEFAULT_DELETE_MAX_ARCHIVE_SIZE: float = 5 * 1024 * 1024 * 1024 # 5GB


class CleaningOperations(Enum):
    """Top-level operations supported by the storage-cleaner CLI."""

    # Delete run directories/archives judged to be bad.
    DELETE_BAD_RUNS = 1
    # Convert sharded checkpoints of a run into unsharded form.
    UNSHARD_CHECKPOINTS = 2


class StorageType(util.StrEnum):
Expand Down Expand Up @@ -594,6 +604,12 @@ class DeleteBadRunsConfig:
max_archive_size: Optional[int]


@dataclass
class UnshardCheckpointsConfig:
    """Options controlling the unshard-checkpoints operation."""

    # If True, only log what would be unsharded; do not perform any unsharding.
    dry_run: bool
    # If True, only the latest (highest step number) sharded checkpoint of a run is unsharded.
    latest_checkpoint_only: bool


def _get_storage_adapter_for_path(path: str) -> StorageAdapter:
    """Create a ``StorageAdapter`` matching the storage type of ``path``."""
    return StorageAdapter.create_storage_adapter(StorageAdapter.get_storage_type_for_path(path))
Expand Down Expand Up @@ -648,12 +664,21 @@ def _unarchive_if_archive(dir_or_archive: str, storage: StorageAdapter) -> str:
unarchived_dir = cached_path(dir_or_archive, extract_archive=True)
assert unarchived_dir != Path(dir_or_archive)

# The unarchived file could have a redundant top-level directory. If the top-level
# directory has only a directory, we should return that directory instead.
unarchived_dir_storage = _get_storage_adapter_for_path(str(unarchived_dir))
unarchived_dir_entries = unarchived_dir_storage.list_entries(str(unarchived_dir))
if len(unarchived_dir_entries) == 1:
unarchived_entry_path = unarchived_dir / unarchived_dir_entries[0]
if unarchived_dir_storage.is_dir(str(unarchived_entry_path)):
return str(unarchived_entry_path)

return str(unarchived_dir)

if storage.is_dir(dir_or_archive):
return dir_or_archive

raise ValueError(f"Run dir or archive {dir_or_archive} is not a valid archive file or directory")
raise ValueError(f"Dir or archive {dir_or_archive} is not a valid archive file or directory")


def _should_delete_run(storage: StorageAdapter, run_dir_or_archive: str, config: DeleteBadRunsConfig) -> bool:
Expand Down Expand Up @@ -705,6 +730,165 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig):
_delete_if_bad_run(storage, run_path, config)


def _is_sharded_checkpoint_dir(directory: str) -> bool:
    """Return True if ``directory`` exists as a directory and is named like a sharded checkpoint (``stepN``)."""
    storage = _get_storage_adapter_for_path(directory)
    if not storage.is_dir(directory):
        return False
    # Unsharded checkpoints ("stepN-unsharded") deliberately do not match.
    return re.match(r"step\d+$", Path(directory).name) is not None


def _get_checkpoint_number(checkpoint_dir: str) -> int:
checkpoint_dir_name = Path(checkpoint_dir).name
checkpoint_dir_name = checkpoint_dir_name.removesuffix("-unsharded")
match = re.match(r"step(\d+)$", checkpoint_dir_name)
if match is None:
raise ValueError(f"Failed to find checkpoint number for dir {checkpoint_dir}")

return int(match.group(1))


def _get_sharded_checkpoint_dirs(
    run_dir_storage: StorageAdapter, run_dir: str, run_dir_or_archive: str, latest_checkpoint_only: bool
) -> List[str]:
    """Find the sharded checkpoint directories directly under ``run_dir``.

    When ``latest_checkpoint_only`` is set, at most the checkpoint with the
    highest step number is returned.
    """
    subdir_paths = [os.path.join(run_dir, name) for name in run_dir_storage.list_dirs(run_dir)]
    sharded_checkpoint_directories = [path for path in subdir_paths if _is_sharded_checkpoint_dir(path)]

    if latest_checkpoint_only:
        newest = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number)
        sharded_checkpoint_directories = [] if newest is None else [newest]

    log.info(
        "Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive
    )

    return sharded_checkpoint_directories


def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str):
    """Unshard a single sharded checkpoint and upload the result to ``dest_dir``.

    The sharded checkpoint is downloaded into a local temp dir, unsharded with
    the checkpointer type recorded in its ``config.yaml``, written out as
    ``model.pt`` / ``optim.pt`` / ``train.pt`` plus a copy of ``config.yaml``,
    and uploaded to ``dest_dir``. A ``RuntimeError`` during unsharding is
    logged and swallowed (the checkpoint is skipped). Both temp dirs are
    always removed — previously they were deleted only on success or on a
    caught ``RuntimeError``, so any other failure (e.g. the unsupported
    checkpointer-type branch, a download or save error) leaked them.
    """
    local_storage = LocalFileSystemAdapter()

    # Temp workspace: input holds the downloaded sharded checkpoint, output
    # the unsharded artifacts. Both removed in the `finally` below.
    sharding_input_dir = local_storage.create_temp_dir()
    sharding_output_dir = local_storage.create_temp_dir()

    try:
        # Download checkpoint to the input temp dir.
        src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
        src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)

        try:
            # The run's config determines which sharded checkpointer wrote this checkpoint.
            config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False)
            sharded_checkpoint_type = config.sharded_checkpointer
            checkpointer: Checkpointer
            if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
                checkpointer = TorchLegacyShardedCheckpointer(config)
            elif sharded_checkpoint_type == ShardedCheckpointerType.local:
                checkpointer = LocalShardedCheckpointer(config)
            else:
                raise NotImplementedError(sharded_checkpoint_type)

            model_state_dict, optim_state_dict, trainer_state_dict = checkpointer.unshard_checkpoint(
                sharding_input_dir
            )
        except RuntimeError as e:
            # Unsharding failures are logged and skipped rather than aborting
            # the whole cleaning run.
            log.error(
                "Unsharding from %s to %s failed with exception: %s",
                sharding_input_dir,
                sharding_output_dir,
                e,
            )
            return

        # model — save then free immediately to keep peak memory down.
        model_output = str(Path(sharding_output_dir) / "model.pt")
        log.info("Saving model state to %s", model_output)
        torch.save(model_state_dict, model_output)
        del model_state_dict

        # optimizer
        optim_output = str(Path(sharding_output_dir) / "optim.pt")
        log.info("Saving optimizer state to %s", optim_output)
        torch.save(optim_state_dict, optim_output)
        del optim_state_dict

        # trainer
        train_output = str(Path(sharding_output_dir) / "train.pt")
        log.info("Saving everything else to %s", train_output)
        torch.save(trainer_state_dict, train_output)
        del trainer_state_dict

        log.info("Copying config.yaml to %s", sharding_output_dir)
        shutil.copy(Path(sharding_input_dir) / "config.yaml", sharding_output_dir)

        log.info(
            "Successfully unsharded from %s to %s, starting upload to %s",
            sharding_input_dir,
            sharding_output_dir,
            dest_dir,
        )

        dest_storage = _get_storage_adapter_for_path(dest_dir)
        dest_storage.upload(sharding_output_dir, dest_dir)
    finally:
        # Always clean up local temp dirs — success, skipped checkpoint, or
        # unexpected failure alike.
        local_storage.delete_path(sharding_input_dir)
        local_storage.delete_path(sharding_output_dir)


def _unshard_checkpoints(
    run_storage: StorageAdapter,
    run_dir_or_archive: str,
    checkpoints_dest_dir: str,
    config: UnshardCheckpointsConfig,
):
    """Unshard every eligible sharded checkpoint of a run into ``checkpoints_dest_dir``.

    Checkpoints that already have an unsharded counterpart at the source or
    destination are skipped; ``config.dry_run`` only logs what would happen.
    """
    log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive)

    run_dir = _unarchive_if_archive(run_dir_or_archive, run_storage)
    run_dir_storage = _get_storage_adapter_for_path(run_dir)

    for sharded_dir in _get_sharded_checkpoint_dirs(
        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only
    ):
        checkpoint_name = Path(sharded_dir).name
        unsharded_name = f"{checkpoint_name}-unsharded"

        # Skip if the run itself already contains the unsharded checkpoint.
        source_unsharded_dir = os.path.join(run_dir, unsharded_name)
        if run_dir_storage.is_dir(source_unsharded_dir):
            log.info(
                "Unsharded directory already exists for %s at source %s, skipping",
                checkpoint_name,
                source_unsharded_dir,
            )
            continue

        # Skip if the destination already holds an unsharded copy.
        dest_directory = os.path.join(checkpoints_dest_dir, unsharded_name)
        if _get_storage_adapter_for_path(dest_directory).is_dir(dest_directory):
            log.info(
                "Unsharded directory already exists for %s at destination %s, skipping",
                checkpoint_name,
                dest_directory,
            )
            continue

        if config.dry_run:
            log.info("Would unshard sharded checkpoint %s to %s", sharded_dir, dest_directory)
        else:
            log.info("Unsharding sharded checkpoint %s to %s", sharded_dir, dest_directory)
            _unshard_checkpoint(sharded_dir, dest_directory)


def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
    """Entry point: unshard the checkpoints of the run at ``run_path`` into ``checkpoints_dest_dir``."""
    run_storage = _get_storage_adapter_for_path(run_path)
    _unshard_checkpoints(
        run_storage,
        _format_dir_or_archive_path(run_storage, run_path),
        checkpoints_dest_dir,
        config,
    )


def perform_operation(args: argparse.Namespace):
if args.dry_run:
log.info("Dry run, no irreversible actions will be taken")
Expand All @@ -720,6 +904,15 @@ def perform_operation(args: argparse.Namespace):
delete_bad_runs(args.run_paths, delete_bad_runs_config)
else:
raise ValueError("Run paths not provided for run cleaning")
elif args.op == CleaningOperations.UNSHARD_CHECKPOINTS:
unshard_checkpoints_config = UnshardCheckpointsConfig(
dry_run=args.dry_run,
latest_checkpoint_only=args.latest_checkpoint_only,
)
if args.run_path is not None:
unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
else:
raise ValueError("Run path not provided for unsharding")
else:
raise NotImplementedError(args.op)

Expand Down Expand Up @@ -785,11 +978,30 @@ def _add_delete_subparser(subparsers: _SubParsersAction):

delete_runs_parser.add_argument(
"--max_archive_size",
default=DEFAULT_MAX_ARCHIVE_SIZE,
default=DEFAULT_DELETE_MAX_ARCHIVE_SIZE,
help="Max size archive files to consider for deletion (in bytes). Any archive larger than this is ignored/not deleted.",
)


def _add_unsharding_subparser(subparsers: _SubParsersAction):
    """Register the ``unshard`` subcommand and its arguments."""
    parser: ArgumentParser = subparsers.add_parser("unshard", help="unshard checkpoints of a run")
    # Route parsed args to the unshard operation in perform_operation.
    parser.set_defaults(op=CleaningOperations.UNSHARD_CHECKPOINTS)

    parser.add_argument(
        "run_path",
        help="Path to run directory or archive containing checkpoints to unshard.",
    )
    parser.add_argument(
        "dest_dir",
        help="Path to directory where the run's unsharded checkpoints should be output (only the unsharded checkpoints are stored).",
    )
    parser.add_argument(
        "--latest_checkpoint_only",
        action="store_true",
        help="If set, only the latest checkpoint of each run (if sharded) is unsharded.",
    )


def get_parser() -> ArgumentParser:
parser = ArgumentParser()
parser.add_argument(
Expand All @@ -805,6 +1017,7 @@ def get_parser() -> ArgumentParser:

subparsers = parser.add_subparsers(dest="command", help="Cleaning commands", required=True)
_add_delete_subparser(subparsers)
_add_unsharding_subparser(subparsers)

return parser

Expand Down

0 comments on commit 1ede949

Please sign in to comment.