Merge pull request #742 from allenai/GoogleStorage
Improved support for Google Storage
dirkgr authored Nov 7, 2024
2 parents afd728f + 67132f6 commit 31c385f
Showing 2 changed files with 71 additions and 5 deletions.
CHANGELOG.md: 1 addition, 0 deletions

@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `constant_with_warmup` learning rate schedule
 - `one_in_eight` configuration for activation checkpointing
 - New tokenizer in the source instead of from huggingface
+- Improved support for GCS
 
 
 ## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17
olmo/util.py: 70 additions, 5 deletions

@@ -19,9 +19,12 @@
 import boto3
 import botocore.exceptions as boto_exceptions
 import datasets
+import requests
 import rich
 from botocore.config import Config
 from cached_path.schemes import SchemeClient, add_scheme_client
+from google.api_core.retry import Retry as GCSRetry
+from google.api_core.retry import if_transient_error as gcs_is_transient_error
 from rich.console import Console, ConsoleRenderable
 from rich.highlighter import NullHighlighter
 from rich.progress import Progress
@@ -393,7 +396,7 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:

     parsed = urlparse(str(dir))
     if parsed.scheme == "gs":
-        raise NotImplementedError
+        return _gcs_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
     elif parsed.scheme in ("s3", "r2", "weka"):
         return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
     elif parsed.scheme == "file":
@@ -416,6 +419,18 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
         return latest_checkpoint


+# Google Storage API is unhinged and requires you to specify the retry policy on every single call you make.
+def _gcs_is_retriable(exception: Exception) -> bool:
+    if gcs_is_transient_error(exception):
+        return True
+    if isinstance(exception, requests.exceptions.ReadTimeout):
+        return True
+    return False
+
+
+_gcs_retry = GCSRetry(predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0)
+
+
 def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
     from google.cloud import storage as gcs

@@ -424,7 +439,7 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
     blob = bucket.blob(key)
     if not save_overwrite and blob.exists():
         raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
-    blob.upload_from_filename(source)
+    blob.upload_from_filename(source, retry=_gcs_retry)
 
 
 def _gcs_file_size(bucket_name: str, key: str) -> int:
@@ -435,7 +450,7 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(key)
     try:
-        blob.reload()
+        blob.reload(retry=_gcs_retry)
     except NotFound:
         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
     assert blob.size is not None
@@ -450,10 +465,60 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(key)
     try:
-        blob.reload()
+        return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry)
     except NotFound:
         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
-    return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
 
 
+@cache
+def _get_gcs_client():
+    from google.cloud import storage as gcs
+
+    return gcs.Client()
+
+
+def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
+    if not prefix.endswith("/"):
+        prefix = f"{prefix}/"
+
+    storage_client = _get_gcs_client()
+    bucket = storage_client.bucket(bucket_name)
+    suffix = "/config.yaml"
+    latest_step: Optional[int] = None
+    latest_checkpoint: Optional[str] = None
+    for blob in bucket.list_blobs(prefix=prefix, match_glob=f"**{suffix}"):
+        # Disregard checkpoints that have an empty config file.
+        if blob.size <= 0:
+            continue
+
+        name = blob.name[len(prefix) : -len(suffix)]
+
+        if "/" in name:
+            # We're not considering checkpoints in subdirectories.
+            continue
+
+        if not name.startswith("step"):
+            continue
+        name = name[4:]
+
+        if name.endswith("-unsharded"):
+            name = name[: -len("-unsharded")]
+
+        try:
+            step = int(name)
+        except ValueError:
+            continue
+
+        # we prefer sharded checkpoints to unsharded ones
+        if (
+            latest_step is None
+            or step > latest_step
+            or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded"))
+        ):
+            latest_step = step
+            latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}"
+
+    return latest_checkpoint
+
+
 def _get_s3_profile_name(scheme: str) -> Optional[str]:
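
A note on the retry policy introduced here: `google.api_core.retry.Retry` retries only when its predicate returns True, backing off exponentially (1s, 2s, 4s, 8s, then capped at 10s between attempts) until the 600s deadline expires. Below is a minimal sketch of the same policy in isolation; the bucket and object names are hypothetical, and `Retry` instances also work as decorators for calls that take no `retry=` argument:

```python
# Sketch only: mirrors the _gcs_retry policy from this commit. The bucket
# and object names below are hypothetical placeholders, not part of the commit.
import requests
from google.api_core.retry import Retry, if_transient_error
from google.cloud import storage as gcs


def _is_retriable(exc: Exception) -> bool:
    # Retry GCS's own transient errors, plus ReadTimeouts from the
    # underlying requests transport, which if_transient_error misses.
    return if_transient_error(exc) or isinstance(exc, requests.exceptions.ReadTimeout)


# Waits 1s, 2s, 4s, 8s, 10s, 10s, ... between attempts; gives up after 600s total.
retry_policy = Retry(predicate=_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0)

blob = gcs.Client().bucket("my-bucket").blob("path/to/object")
data = blob.download_as_bytes(retry=retry_policy)  # per-call policy, as in the commit


# The policy can also wrap a whole callable that has no retry kwarg:
@retry_policy
def fetch_metadata() -> dict:
    blob.reload()
    return blob.metadata or {}
```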

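To make the selection rule in `_gcs_find_latest_checkpoint` concrete: among blobs named `<prefix>step<N>[-unsharded]/config.yaml`, the highest integer step wins, and when a sharded and an unsharded checkpoint share a step, the sharded one is preferred. A self-contained walkthrough of that rule over hypothetical blob names, no bucket required:

```python
# Worked example of the checkpoint-selection rule in _gcs_find_latest_checkpoint,
# run over hypothetical blob names instead of a live bucket.
from typing import Optional

prefix = "checkpoints/run1/"
suffix = "/config.yaml"
blob_names = [
    "checkpoints/run1/step500-unsharded/config.yaml",
    "checkpoints/run1/step1000/config.yaml",
    "checkpoints/run1/step1000-unsharded/config.yaml",  # same step: sharded wins
    "checkpoints/run1/step900/extra/config.yaml",  # subdirectory: ignored
    "checkpoints/run1/stepfinal/config.yaml",  # not an integer step: ignored
]

latest_step: Optional[int] = None
latest_checkpoint: Optional[str] = None
for blob_name in blob_names:
    name = blob_name[len(prefix) : -len(suffix)]
    if "/" in name or not name.startswith("step"):
        continue
    name = name[4:]
    if name.endswith("-unsharded"):
        name = name[: -len("-unsharded")]
    try:
        step = int(name)
    except ValueError:
        continue
    if (
        latest_step is None
        or step > latest_step
        or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded"))
    ):
        latest_step = step
        latest_checkpoint = blob_name[: -len(suffix)]

print(latest_checkpoint)  # checkpoints/run1/step1000
```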