From a95ecaa45b03a8eb558de58891eb3baca6fdf81c Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 16:46:38 -0700 Subject: [PATCH 01/11] More retries for GCS --- olmo/util.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/olmo/util.py b/olmo/util.py index 3697e86ce..66e29c1d8 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -456,6 +456,21 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1) +@cache +def _get_gcs_client(): + from google.cloud import storage as gcs + from google.api_core.retry import Retry + + return gcs.Client( + client_options={"retry": Retry( + initial=1.0, + maximum=10.0, + multiplier=2.0, + deadline=500.0 + )} + ) + + def _get_s3_profile_name(scheme: str) -> Optional[str]: if scheme == "s3": # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set. From 71aae87920c0313a8c91bd42d0e17fb76fb7bd86 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 17:01:01 -0700 Subject: [PATCH 02/11] Turns out Google APIs don't work that way. --- olmo/util.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 66e29c1d8..6f603c14d 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -22,6 +22,7 @@ import rich from botocore.config import Config from cached_path.schemes import SchemeClient, add_scheme_client +from composer.utils import retry from rich.console import Console, ConsoleRenderable from rich.highlighter import NullHighlighter from rich.progress import Progress @@ -416,6 +417,16 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: return latest_checkpoint +# Google Storage API is unhinged and requires you to specify the retry policy on every single call you make. +from google.api_core.retry import Retry +_gcs_retry = Retry( + initial=1.0, + maximum=10.0, + multiplier=2.0, + deadline=500.0 +) + + def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): from google.cloud import storage as gcs @@ -424,7 +435,7 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = blob = bucket.blob(key) if not save_overwrite and blob.exists(): raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.") - blob.upload_from_filename(source) + blob.upload_from_filename(source, retry=_gcs_retry) def _gcs_file_size(bucket_name: str, key: str) -> int: @@ -435,7 +446,7 @@ def _gcs_file_size(bucket_name: str, key: str) -> int: bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) try: - blob.reload() + blob.reload(retry=_gcs_retry) except NotFound: raise FileNotFoundError(f"gs://{bucket_name}/{key}") assert blob.size is not None @@ -450,25 +461,16 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) try: - blob.reload() + blob.reload(retry=_gcs_retry) except NotFound: raise FileNotFoundError(f"gs://{bucket_name}/{key}") - return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1) + return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry) @cache def _get_gcs_client(): from google.cloud import storage as gcs - from google.api_core.retry import Retry - - return gcs.Client( - client_options={"retry": Retry( - initial=1.0, - maximum=10.0, - multiplier=2.0, - deadline=500.0 - )} - ) + return gcs.Client() def _get_s3_profile_name(scheme: str) -> Optional[str]: From ab67d52e0019c04cc72fb4904d0306d655e8e767 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 21:37:40 -0700 Subject: [PATCH 03/11] Makes finding the latest checkpoint work on GCS --- olmo/util.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/olmo/util.py b/olmo/util.py index 6f603c14d..8a0ada418 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -394,7 +394,7 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: parsed = urlparse(str(dir)) if parsed.scheme == "gs": - raise NotImplementedError + return _gcs_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/")) elif parsed.scheme in ("s3", "r2", "weka"): return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/")) elif parsed.scheme == "file": @@ -473,6 +473,53 @@ def _get_gcs_client(): return gcs.Client() +def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: + if not prefix.endswith("/"): + prefix = f"{prefix}/" + + storage_client = _get_gcs_client() + bucket = storage_client.bucket(bucket_name) + suffix = "/config.yaml" + latest_step: Optional[int] = None + latest_checkpoint: Optional[str] = None + for blob in bucket.list_blobs(prefix=prefix, match_glob=f"**{suffix}"): + # Disregard checkpoints that have an empty config file. + if blob.size <= 0: + continue + + name = blob.name[len(prefix):-len(suffix)] + + if "/" in name: + # We're not considering checkpoints in subdirectories. + continue + + if not name.startswith("step"): + continue + name = name[4:] + + if name.endswith("-unsharded"): + name = name[:-len("-unsharded")] + unsharded = True + else: + unsharded = False + + try: + step = int(name) + except ValueError: + continue + + # we prefer sharded checkpoints to unsharded ones + if ( + latest_step is None or + step > latest_step or + step == latest_step and latest_checkpoint.endswith("-unsharded") + ): + latest_step = step + latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}" + + return latest_checkpoint + + def _get_s3_profile_name(scheme: str) -> Optional[str]: if scheme == "s3": # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set. From 80f73b154926993bcb8d886b0f369576764e7374 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 21:42:21 -0700 Subject: [PATCH 04/11] Remove unused code --- olmo/util.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 8a0ada418..f5ebe47f5 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -499,9 +499,6 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: if name.endswith("-unsharded"): name = name[:-len("-unsharded")] - unsharded = True - else: - unsharded = False try: step = int(name) From c8d0ddc9bf0bc338dd8570cdc40dc13867deced0 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 21:44:30 -0700 Subject: [PATCH 05/11] Fix imports --- olmo/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index f5ebe47f5..db2d07387 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -22,7 +22,7 @@ import rich from botocore.config import Config from cached_path.schemes import SchemeClient, add_scheme_client -from composer.utils import retry +from google.api_core.retry import Retry as GCSRetry from rich.console import Console, ConsoleRenderable from rich.highlighter import NullHighlighter from rich.progress import Progress @@ -418,8 +418,7 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: # Google Storage API is unhinged and requires you to specify the retry policy on every single call you make. -from google.api_core.retry import Retry -_gcs_retry = Retry( +_gcs_retry = GCSRetry( initial=1.0, maximum=10.0, multiplier=2.0, From 2b86a33b98745d8a9745282ac601fee1b8fc1ddf Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 21:45:16 -0700 Subject: [PATCH 06/11] Make the code less readable --- olmo/util.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index db2d07387..d40a38396 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -418,12 +418,7 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: # Google Storage API is unhinged and requires you to specify the retry policy on every single call you make. -_gcs_retry = GCSRetry( - initial=1.0, - maximum=10.0, - multiplier=2.0, - deadline=500.0 -) +_gcs_retry = GCSRetry(initial=1.0, maximum=10.0, multiplier=2.0, deadline=500.0) def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): @@ -469,6 +464,7 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes @cache def _get_gcs_client(): from google.cloud import storage as gcs + return gcs.Client() @@ -486,7 +482,7 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: if blob.size <= 0: continue - name = blob.name[len(prefix):-len(suffix)] + name = blob.name[len(prefix) : -len(suffix)] if "/" in name: # We're not considering checkpoints in subdirectories. @@ -497,7 +493,7 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: name = name[4:] if name.endswith("-unsharded"): - name = name[:-len("-unsharded")] + name = name[: -len("-unsharded")] try: step = int(name) @@ -506,9 +502,10 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: # we prefer sharded checkpoints to unsharded ones if ( - latest_step is None or - step > latest_step or - step == latest_step and latest_checkpoint.endswith("-unsharded") + latest_step is None + or step > latest_step + or step == latest_step + and latest_checkpoint.endswith("-unsharded") ): latest_step = step latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}" From f6fda1315a50f0da80444962c14ebb90206f51bd Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 16 Oct 2024 21:46:45 -0700 Subject: [PATCH 07/11] Make mypy happy --- olmo/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index d40a38396..af342f16e 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -504,8 +504,7 @@ def _gcs_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: if ( latest_step is None or step > latest_step - or step == latest_step - and latest_checkpoint.endswith("-unsharded") + or (step == latest_step and latest_checkpoint is not None and latest_checkpoint.endswith("-unsharded")) ): latest_step = step latest_checkpoint = f"gs://{bucket_name}/{blob.name[:-len(suffix)]}" From 005860e9ddb9ae2eda91d26b3bc30d89ae13913c Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Fri, 25 Oct 2024 13:51:48 -0700 Subject: [PATCH 08/11] More robust GCS downloads --- olmo/util.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index af342f16e..c2a7a9a0c 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -22,12 +22,13 @@ import rich from botocore.config import Config from cached_path.schemes import SchemeClient, add_scheme_client -from google.api_core.retry import Retry as GCSRetry +from google.api_core.retry import Retry as GCSRetry, if_transient_error as gcs_is_transient_error from rich.console import Console, ConsoleRenderable from rich.highlighter import NullHighlighter from rich.progress import Progress from rich.text import Text from rich.traceback import Traceback +import requests from olmo_data.data import get_data_path @@ -418,7 +419,20 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: # Google Storage API is unhinged and requires you to specify the retry policy on every single call you make. -_gcs_retry = GCSRetry(initial=1.0, maximum=10.0, multiplier=2.0, deadline=500.0) +def _gcs_is_retriable(exception: Exception) -> bool: + if gcs_is_transient_error(exception): + return True + if isinstance(exception, requests.exceptions.ReadTimeout): + return True + return False + + +_gcs_retry = GCSRetry( + predicate=_gcs_is_retriable, + initial=1.0, + maximum=10.0, + multiplier=2.0, + timeout=600.0) def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): @@ -455,10 +469,9 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) try: - blob.reload(retry=_gcs_retry) + return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry) except NotFound: raise FileNotFoundError(f"gs://{bucket_name}/{key}") - return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1, retry=_gcs_retry) @cache From 95db47f5c905c24ebaefc460605cc82185b3bac3 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Fri, 1 Nov 2024 17:45:14 -0700 Subject: [PATCH 09/11] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dcd076b6a..8899e60ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `constant_with_warmup` learning rate schedule - `one_in_eight` configuration for activation checkpointing - New tokenizer in the source instead of from huggingface +- Improved support for GCS ## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17 From 2d3800bb7b38c87cea54f50d054f9fc1b7f39a57 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Fri, 1 Nov 2024 17:46:00 -0700 Subject: [PATCH 10/11] Productivity through formatting --- olmo/util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index c2a7a9a0c..63bc75883 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -427,12 +427,7 @@ def _gcs_is_retriable(exception: Exception) -> bool: return False -_gcs_retry = GCSRetry( - predicate=_gcs_is_retriable, - initial=1.0, - maximum=10.0, - multiplier=2.0, - timeout=600.0) +_gcs_retry = GCSRetry(predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0) def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): From 67132f682c152c4c8462be2954611c4268bf2e57 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 6 Nov 2024 16:28:31 -0800 Subject: [PATCH 11/11] isort --- olmo/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 63bc75883..c448d7e2c 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -19,16 +19,17 @@ import boto3 import botocore.exceptions as boto_exceptions import datasets +import requests import rich from botocore.config import Config from cached_path.schemes import SchemeClient, add_scheme_client -from google.api_core.retry import Retry as GCSRetry, if_transient_error as gcs_is_transient_error +from google.api_core.retry import Retry as GCSRetry +from google.api_core.retry import if_transient_error as gcs_is_transient_error from rich.console import Console, ConsoleRenderable from rich.highlighter import NullHighlighter from rich.progress import Progress from rich.text import Text from rich.traceback import Traceback -import requests from olmo_data.data import get_data_path