Merge pull request #483 from allenai/shanea/storage-cleaner-unshard-improvements

[Storage Cleaner] Unsharding improvements
2015aroras authored Mar 6, 2024
2 parents 1d264e4 + ba96c6a commit fd3a57b
1 changed file: scripts/storage_cleaner.py, with 70 additions and 8 deletions.
@@ -15,6 +15,7 @@
 import boto3.session
 import botocore.exceptions as boto_exceptions
 import google.cloud.storage as gcs
+import omegaconf
 import torch
 import wandb
 from boto3.s3.transfer import TransferConfig
@@ -622,6 +623,8 @@ class DeleteBadRunsConfig(StorageCleanerConfig):
 @dataclass
 class UnshardCheckpointsConfig(StorageCleanerConfig):
     latest_checkpoint_only: bool
+    delete_sharded_checkpoints: bool
+    checkpoint_num: Optional[int]


 @dataclass
@@ -765,9 +768,13 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig):
         shutil.rmtree(config.temp_dir)


-def _is_sharded_checkpoint_dir(directory: str) -> bool:
+def _is_checkpoint_dir(directory: str) -> bool:
     storage = _get_storage_adapter_for_path(directory)
-    return storage.is_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None
+    return storage.is_dir(directory) and re.match(r"step\d+(-unsharded)?$", Path(directory).name) is not None
+
+
+def _is_sharded_checkpoint_dir(directory: str) -> bool:
+    return _is_checkpoint_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None


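
As a point of reference, here is a minimal standalone sketch (not taken from the script) of how the two predicates above classify a few hypothetical directory names; the storage.is_dir() check is omitted so the snippet runs without any storage backend.

# Illustration only; the directory names are hypothetical examples.
import re
from pathlib import Path

for name in ("step1000", "step1000-unsharded", "latest"):
    is_checkpoint = re.match(r"step\d+(-unsharded)?$", Path(name).name) is not None
    is_sharded = re.match(r"step\d+$", Path(name).name) is not None
    print(f"{name}: checkpoint={is_checkpoint}, sharded={is_sharded}")
# step1000: checkpoint=True, sharded=True
# step1000-unsharded: checkpoint=True, sharded=False
# latest: checkpoint=False, sharded=False
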
def _get_checkpoint_number(checkpoint_dir: str) -> int:
Expand All @@ -781,17 +788,31 @@ def _get_checkpoint_number(checkpoint_dir: str) -> int:


 def _get_sharded_checkpoint_dirs(
-    run_dir_storage: StorageAdapter, run_dir: str, run_dir_or_archive: str, latest_checkpoint_only: bool
+    run_dir_storage: StorageAdapter,
+    run_dir: str,
+    run_dir_or_archive: str,
+    latest_checkpoint_only: bool,
+    checkpoint_num: Optional[int] = None,
 ) -> List[str]:
     run_subdir_names = run_dir_storage.list_dirs(run_dir)
     run_subdirectories = list(map(lambda dir_name: os.path.join(run_dir, dir_name), run_subdir_names))
     sharded_checkpoint_directories = list(filter(_is_sharded_checkpoint_dir, run_subdirectories))

+    if latest_checkpoint_only and checkpoint_num is not None:
+        raise ValueError("Cannot set both 'latest_checkpoint_only' and 'checkpoint_num'")
+
     if latest_checkpoint_only:
         latest_checkpoint_directory = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number)
         sharded_checkpoint_directories = (
             [latest_checkpoint_directory] if latest_checkpoint_directory is not None else []
         )
+    elif checkpoint_num is not None:
+        sharded_checkpoint_directories = [
+            sharded_checkpoint_dir
+            for sharded_checkpoint_dir in sharded_checkpoint_directories
+            if _get_checkpoint_number(sharded_checkpoint_dir) == checkpoint_num
+        ]
+        assert len(sharded_checkpoint_directories) <= 1

     log.info(
         "Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive
@@ -844,13 +865,29 @@ def _unshard_checkpoint(
     sharding_output_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)

     try:
-        config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False)
-        sharded_checkpoint_type = config.sharded_checkpointer
+        # `TrainConfig` is not backwards-compatible with all older checkpoints, so
+        # we need to load the yaml directly.
+        raw_config = om.load(str(Path(sharding_input_dir) / "config.yaml"))
+        assert isinstance(raw_config, omegaconf.DictConfig)
+
+        sharded_checkpoint_type_str = raw_config.get("sharded_checkpointer", "torch_legacy")
+        if sharded_checkpoint_type_str == "legacy":
+            # At some point, the enum string for ShardedCheckpointerType.torch_legacy was "legacy"
+            sharded_checkpoint_type_str = "torch_legacy"
+
+        sharded_checkpoint_type = ShardedCheckpointerType[sharded_checkpoint_type_str]
+
+        # The ShardedCheckpointers require a `TrainConfig` to be passed in, but
+        # legacy configs are not all compatible with this class. None of the config
+        # settings are needed for unsharding, so we pass in a dummy config instead.
+        # This is a hack, but decoupling unsharding for checkpoint saving/loading
+        # seems like overkill.
+        dummy_config = TrainConfig.new()
         checkpointer: Checkpointer
         if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-            checkpointer = TorchLegacyShardedCheckpointer(config)
+            checkpointer = TorchLegacyShardedCheckpointer(dummy_config)
         elif sharded_checkpoint_type == ShardedCheckpointerType.local:
-            checkpointer = LocalShardedCheckpointer(config)
+            checkpointer = LocalShardedCheckpointer(dummy_config)
         else:
             raise NotImplementedError(sharded_checkpoint_type)

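
The key change above is that config.yaml is now read with OmegaConf rather than through TrainConfig, so checkpoints written before newer config fields existed can still be unsharded. A minimal sketch of that pattern (the checkpoint path is a hypothetical example; only the keys shown in the diff are assumed):

# Illustration only; the checkpoint path is a hypothetical example.
from pathlib import Path
from omegaconf import OmegaConf as om

raw_config = om.load(str(Path("/tmp/run/step1000") / "config.yaml"))
# .get() tolerates configs that predate the sharded_checkpointer field.
checkpointer_name = raw_config.get("sharded_checkpointer", "torch_legacy")
if checkpointer_name == "legacy":  # older spelling of torch_legacy
    checkpointer_name = "torch_legacy"
print(checkpointer_name)
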
@@ -911,11 +948,14 @@ def _unshard_checkpoints(
 ):
     log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive)

+    if config.delete_sharded_checkpoints and _is_archive(run_dir_or_archive, run_storage):
+        raise ValueError("Cannot delete sharded checkpoints of run archive files")
+
     run_dir = _unarchive_if_archive(run_dir_or_archive, run_storage)
     run_dir_storage = _get_storage_adapter_for_path(run_dir)

     sharded_checkpoint_directories = _get_sharded_checkpoint_dirs(
-        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only
+        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only, config.checkpoint_num
     )
     for sharded_checkpoint_directory in sharded_checkpoint_directories:
         sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name
@@ -947,6 +987,14 @@ def _unshard_checkpoints(
log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory)
_unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir, config)

if config.delete_sharded_checkpoints:
assert run_dir == run_dir_or_archive
if config.dry_run:
log.info("Would delete sharded checkpoint %s", sharded_checkpoint_directory)
else:
log.info("Deleting sharded checkpoint %s", sharded_checkpoint_directory)
run_dir_storage.delete_path(sharded_checkpoint_directory)


def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
storage = _get_storage_adapter_for_path(run_path)
@@ -1252,6 +1300,8 @@ def perform_operation(args: argparse.Namespace):
             dry_run=args.dry_run,
             temp_dir=temp_dir,
             latest_checkpoint_only=args.latest_checkpoint_only,
+            delete_sharded_checkpoints=args.delete_sharded_checkpoints,
+            checkpoint_num=args.checkpoint_num,
         )
         if args.run_path is not None:
             unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
@@ -1327,6 +1377,18 @@ def _add_unsharding_subparser(subparsers: _SubParsersAction):
action="store_true",
help="If set, only the latest checkpoint of each run (if sharded) is unsharded.",
)
unsharding_runs_parser.add_argument(
"--delete_sharded",
dest="delete_sharded_checkpoints",
action="store_true",
help="If set, deletes sharded checkpoints after they have been successfully unsharded.",
)
unsharding_runs_parser.add_argument(
"--checkpoint_num",
type=int,
default=None,
help="If provided, unsharding is restricted to this checkpoint of the run.",
)


def _add_move_subparser(subparsers: _SubParsersAction):
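
For completeness, a stripped-down parser exercising just the two options added above (a standalone sketch, not part of the script) shows how they surface as args.delete_sharded_checkpoints and args.checkpoint_num, which perform_operation forwards into UnshardCheckpointsConfig.

# Illustration only; a standalone parser with just the two new options.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--delete_sharded",
    dest="delete_sharded_checkpoints",
    action="store_true",
    help="If set, deletes sharded checkpoints after they have been successfully unsharded.",
)
parser.add_argument(
    "--checkpoint_num",
    type=int,
    default=None,
    help="If provided, unsharding is restricted to this checkpoint of the run.",
)

args = parser.parse_args(["--delete_sharded", "--checkpoint_num", "1000"])
print(args.delete_sharded_checkpoints, args.checkpoint_num)  # True 1000
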
