[Storage cleaner] Improve handling of temp directories and files #396

Merged (9 commits) on Dec 11, 2023
scripts/storage_cleaner.py (110 changes: 68 additions & 42 deletions)
@@ -145,8 +145,8 @@ def create_temp_file(self, suffix: Optional[str] = None) -> str:
         self._temp_files.append(temp_file)
         return temp_file.name

-    def create_temp_dir(self, suffix: Optional[str] = None) -> str:
-        temp_dir = tempfile.TemporaryDirectory(suffix=suffix)
+    def create_temp_dir(self, directory: Optional[str] = None, suffix: Optional[str] = None) -> str:
+        temp_dir = tempfile.TemporaryDirectory(dir=directory, suffix=suffix)
         self._temp_dirs.append(temp_dir)
         return temp_dir.name

@@ -597,16 +597,20 @@ def upload_callback(progress: Progress, upload_task: TaskID, bytes_uploaded: int


 @dataclass
-class DeleteBadRunsConfig:
+class StorageCleanerConfig:
     dry_run: bool
+    temp_dir: str
+
+
+@dataclass
+class DeleteBadRunsConfig(StorageCleanerConfig):
     should_check_is_run: bool
     ignore_non_runs: bool
     max_archive_size: Optional[int]


 @dataclass
-class UnshardCheckpointsConfig:
-    dry_run: bool
+class UnshardCheckpointsConfig(StorageCleanerConfig):
     latest_checkpoint_only: bool


@@ -729,6 +733,11 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig):
         log.info("Starting to check if run %s should be deleted", run_path)
         _delete_if_bad_run(storage, run_path, config)

+        # Delete temp dir after each run to avoid storage bloat
+        if Path(config.temp_dir).is_dir():
+            log.info("Deleting temp dir %s", config.temp_dir)
+            shutil.rmtree(config.temp_dir)
+

 def _is_sharded_checkpoint_dir(directory: str) -> bool:
     storage = _get_storage_adapter_for_path(directory)
@@ -788,19 +797,21 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
         log.warning("Cannot find training config to add to checkpoint %s", local_checkpoint_dir)


-def _unshard_checkpoint(sharded_checkpoint_dir: str, dest_dir: str, run_dir: str):
+def _unshard_checkpoint(
+    sharded_checkpoint_dir: str, dest_dir: str, run_dir: str, unsharding_config: UnshardCheckpointsConfig
+):
     local_storage = LocalFileSystemAdapter()

     # Download checkpoint to a temp dir
-    sharding_input_dir = local_storage.create_temp_dir()
+    sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
     src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
     src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)

     _add_training_config_to_checkpoint(sharding_input_dir, run_dir)

     # Set unsharder output to a temp dir
     sharding_output_dir: str
-    sharding_output_dir = local_storage.create_temp_dir()
+    sharding_output_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)

     try:
         config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False)
@@ -905,7 +916,7 @@ def _unshard_checkpoints(
             log.info("Would unshard sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory)
         else:
             log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory)
-            _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir)
+            _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir, config)


 def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
@@ -914,34 +925,6 @@ def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
     _unshard_checkpoints(storage, run_dir_or_archive, checkpoints_dest_dir, config)


-def perform_operation(args: argparse.Namespace):
-    if args.dry_run:
-        log.info("Dry run, no irreversible actions will be taken")
-
-    if args.op == CleaningOperations.DELETE_BAD_RUNS:
-        delete_bad_runs_config = DeleteBadRunsConfig(
-            dry_run=args.dry_run,
-            should_check_is_run=args.should_check_is_run,
-            ignore_non_runs=args.ignore_non_runs,
-            max_archive_size=args.max_archive_size,
-        )
-        if args.run_paths is not None:
-            delete_bad_runs(args.run_paths, delete_bad_runs_config)
-        else:
-            raise ValueError("Run paths not provided for run cleaning")
-    elif args.op == CleaningOperations.UNSHARD_CHECKPOINTS:
-        unshard_checkpoints_config = UnshardCheckpointsConfig(
-            dry_run=args.dry_run,
-            latest_checkpoint_only=args.latest_checkpoint_only,
-        )
-        if args.run_path is not None:
-            unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
-        else:
-            raise ValueError("Run path not provided for unsharding")
-    else:
-        raise NotImplementedError(args.op)
-
-
 def _add_cached_path_s3_client():
     class S3SchemeClient(S3Client):
         """
@@ -969,13 +952,57 @@ def __init__(self, resource: str) -> None:
     add_scheme_client(S3SchemeClient)


-def _setup_cached_path(args: argparse.Namespace):
-    if args.temp_dir is not None:
-        set_cache_dir(args.temp_dir)
+def _setup_cached_path(temp_dir: str):
+    if temp_dir is not None:
+        set_cache_dir(temp_dir)

     _add_cached_path_s3_client()


+def perform_operation(args: argparse.Namespace):
+    if args.dry_run:
+        log.info("Dry run, no irreversible actions will be taken")
+
+    if (
+        args.temp_dir is not None
+        and StorageAdapter.get_storage_type_for_path(args.temp_dir) != StorageType.LOCAL_FS
+    ):
+        raise ValueError("Temporary directory must be a local path")
+
+    temp_dir = tempfile.mkdtemp(dir=args.temp_dir)
+    _setup_cached_path(temp_dir)
+
+    try:
+        if args.op == CleaningOperations.DELETE_BAD_RUNS:
+            delete_bad_runs_config = DeleteBadRunsConfig(
+                dry_run=args.dry_run,
+                temp_dir=temp_dir,
+                should_check_is_run=args.should_check_is_run,
+                ignore_non_runs=args.ignore_non_runs,
+                max_archive_size=args.max_archive_size,
+            )
+            if args.run_paths is not None:
+                delete_bad_runs(args.run_paths, delete_bad_runs_config)
+            else:
+                raise ValueError("Run paths not provided for run cleaning")
+        elif args.op == CleaningOperations.UNSHARD_CHECKPOINTS:
+            unshard_checkpoints_config = UnshardCheckpointsConfig(
+                dry_run=args.dry_run,
+                temp_dir=temp_dir,
+                latest_checkpoint_only=args.latest_checkpoint_only,
+            )
+            if args.run_path is not None:
+                unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
+            else:
+                raise ValueError("Run path not provided for unsharding")
+        else:
+            raise NotImplementedError(args.op)
+    finally:
+        if Path(temp_dir).is_dir():
+            log.info("Deleting temp dir %s", temp_dir)
+            shutil.rmtree(temp_dir)
+
+
 def _add_delete_subparser(subparsers: _SubParsersAction):
     delete_runs_parser: ArgumentParser = subparsers.add_parser(
         "clean", help="Delete bad runs (e.g. runs with no non-trivial checkpoints)"
@@ -1037,7 +1064,7 @@ def get_parser() -> ArgumentParser:
     )
     parser.add_argument(
         "--temp_dir",
-        help="Directory where artifacts (e.g. unarchived directories) can be stored temporarily",
+        help="Local directory where artifacts (e.g. unarchived directories) can be stored temporarily",
     )

     subparsers = parser.add_subparsers(dest="command", help="Cleaning commands", required=True)
@@ -1051,7 +1078,6 @@ def main():
     args = get_parser().parse_args()

     util.prepare_cli_environment()
-    _setup_cached_path(args)
     perform_operation(args)


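For context, the temp-directory lifecycle that the reworked perform_operation follows boils down to the sketch below. This is a hypothetical standalone example assuming only the Python standard library; the function name run_operation and the base_dir parameter are illustrative and not part of the script.

```python
# Hypothetical sketch of the temp-dir handling pattern; not code from this PR.
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Optional

log = logging.getLogger(__name__)


def run_operation(base_dir: Optional[str] = None) -> None:
    # Create a dedicated temp dir, optionally under a user-supplied local path
    # (the analogue of the script's --temp_dir argument).
    temp_dir = tempfile.mkdtemp(dir=base_dir)
    try:
        # ... download or unarchive artifacts into temp_dir and process them ...
        pass
    finally:
        # Clean up even if the operation above raises.
        if Path(temp_dir).is_dir():
            log.info("Deleting temp dir %s", temp_dir)
            shutil.rmtree(temp_dir)
```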