Skip to content

Commit

Permalink
Merge pull request #407 from allenai/shanea/storage-cleaner-avoid-red…
Browse files Browse the repository at this point in the history
…undant-copy

[Storage cleaner] Remove redundant temp dir for copying from local dir
  • Loading branch information
2015aroras authored Dec 20, 2023
2 parents 53217d2 + b03f9c0 commit 3c51402
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def _get_sharded_checkpoint_dirs(
return sharded_checkpoint_directories


def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str) -> bool:
max_train_config_size = 1 * 1024 * 1024 # 1MB

if not StorageAdapter.get_storage_type_for_path(local_checkpoint_dir) == StorageType.LOCAL_FS:
Expand All @@ -811,7 +811,7 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
checkpoint_storage = _get_storage_adapter_for_path(local_checkpoint_dir)
if CONFIG_YAML in checkpoint_storage.list_entries(local_checkpoint_dir, max_file_size=max_train_config_size):
# Config already exists in the checkpoint
return
return False

log.info("%s not found in %s, attempting to get it from %s", CONFIG_YAML, local_checkpoint_dir, run_dir)

Expand All @@ -820,22 +820,26 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
if run_storage.is_file(run_config_yaml_path):
local_config_yaml_path = cached_path(run_config_yaml_path)
shutil.copy(local_config_yaml_path, local_checkpoint_dir)
return
return True

log.warning("Cannot find training config to add to checkpoint %s", local_checkpoint_dir)
return False


def _unshard_checkpoint(
sharded_checkpoint_dir: str, dest_dir: str, run_dir: str, unsharding_config: UnshardCheckpointsConfig
):
local_storage = LocalFileSystemAdapter()

# Download checkpoint to a temp dir
sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
# Download checkpoint to a temp dir if it is in cloud storage
if StorageAdapter.get_storage_type_for_path(sharded_checkpoint_dir) != StorageType.LOCAL_FS:
sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
else:
sharding_input_dir = sharded_checkpoint_dir

_add_training_config_to_checkpoint(sharding_input_dir, run_dir)
training_config_added = _add_training_config_to_checkpoint(sharding_input_dir, run_dir)

# Set unsharder output to a temp dir
sharding_output_dir: str
Expand Down Expand Up @@ -863,7 +867,9 @@ def _unshard_checkpoint(
e,
)

local_storage.delete_path(sharding_input_dir)
if training_config_added:
local_storage.delete_path(str(Path(sharding_input_dir) / CONFIG_YAML))

local_storage.delete_path(sharding_output_dir)
return

Expand Down Expand Up @@ -898,9 +904,6 @@ def _unshard_checkpoint(
dest_storage = _get_storage_adapter_for_path(dest_dir)
dest_storage.upload(sharding_output_dir, dest_dir)

local_storage.delete_path(sharding_input_dir)
local_storage.delete_path(sharding_output_dir)


def _unshard_checkpoints(
run_storage: StorageAdapter,
Expand Down Expand Up @@ -1106,12 +1109,17 @@ def _copy(src_path: str, dest_path: str, temp_dir: str):
if src_is_file:
local_path = cached_path(src_path)
elif src_is_dir:
local_storage = LocalFileSystemAdapter()
local_path = local_storage.create_temp_dir(directory=temp_dir)
src_storage.download_folder(src_path, local_path)
if src_storage_type == StorageType.LOCAL_FS:
local_path = src_path
else:
local_storage = LocalFileSystemAdapter()
local_path = local_storage.create_temp_dir(directory=temp_dir)
log.info("Temporarily downloading %s to %s", src_path, local_path)
src_storage.download_folder(src_path, local_path)
else:
raise ValueError(f"Source path {src_path} does not correspond to a valid file or directory")

log.info("Uploading %s to %s", local_path, dest_path)
dest_storage.upload(local_path, dest_path)


Expand Down

0 comments on commit 3c51402

Please sign in to comment.