Skip to content

Commit 601c997

Browse files
author
Orbax Authors
committed
Add additional support for TENSORSTORE_GCS_BACKEND environment variable
This allows the user to configure the TensorStore GCS backend (`gcs` for HTTP or `gcs_grpc` for gRPC) in additional places in the code that did not previously support it, including when using OCDBT. PiperOrigin-RevId: 809285436
1 parent 493fcf8 commit 601c997

File tree

5 files changed

+150
-22
lines changed

5 files changed

+150
-22
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
- #v1 Modify LeafHandler definitions so that `AbstractLeaf` or
1818
`Type[AbstractLeaf]` are always accepted as valid abstract values.
19+
- Configuring the `TENSORSTORE_GCS_BACKEND` environment variable is now
20+
supported for additional locations in the code, notably when using ocdbt.
1921

2022
## [0.11.25] - 2025-09-10
2123

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,22 @@ def _get_kvstore_for_gcs(ckpt_path: str):
7979
"""
8080
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
8181
if m is None:
82-
raise ValueError(
83-
'The ckpt_path should contain the bucket name and the '
84-
f'file path inside the bucket. Got: {ckpt_path}'
85-
)
82+
# The path might only be a bucket name.
83+
m = re.fullmatch('^gs://([^/]*)$', ckpt_path, re.DOTALL)
84+
if m is None:
85+
raise ValueError(
86+
'The ckpt_path should contain the bucket name and the '
87+
f'file path inside the bucket. Got: {ckpt_path}'
88+
)
8689
gcs_bucket = m.group(1)
87-
path_without_bucket = m.group(2)
90+
path_without_bucket = m.group(2) if m.lastindex == 2 else None
8891
# TODO(stoelinga): Switch to gcs_grpc by default.
8992
# gcs_grpc performs roughly twice as fast as gcs backend.
9093
gcs_backend = os.environ.get('TENSORSTORE_GCS_BACKEND', 'gcs')
91-
return {
92-
'driver': f'{gcs_backend}',
93-
'bucket': gcs_bucket,
94-
'path': path_without_bucket,
95-
}
94+
spec = {'driver': gcs_backend, 'bucket': gcs_bucket}
95+
if path_without_bucket:
96+
spec['path'] = path_without_bucket
97+
return spec
9698

9799

98100
def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
@@ -107,7 +109,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
107109
raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
108110
base_path = os.path.dirname(ckpt_path)
109111
base_driver_spec = (
110-
base_path
112+
_get_kvstore_for_gcs(base_path)
111113
if is_gcs_path
112114
else {'driver': ts_utils.DEFAULT_DRIVER, 'path': base_path}
113115
)

checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,14 @@ def test_get_tensorstore_spec_ocdbt(self, path):
475475
spec = serialization.get_tensorstore_spec(path, ocdbt=True)
476476
is_gcs_path = path.startswith('gs://')
477477
if is_gcs_path:
478-
self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
478+
self.assertEqual(
479+
spec['kvstore']['base'],
480+
{
481+
'driver': 'gcs',
482+
'bucket': 'my',
483+
'path': 'ckpt/dir',
484+
},
485+
)
479486
else:
480487
self.assertEqual(
481488
spec['kvstore']['base'],
@@ -493,6 +500,52 @@ def test_get_tensorstore_spec_not_absolute_path(self):
493500
):
494501
serialization.get_tensorstore_spec(path, ocdbt=True)
495502

503+
@parameterized.named_parameters(
504+
dict(testcase_name='none', backend=None, target_driver='gcs'),
505+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
506+
dict(
507+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
508+
),
509+
)
510+
def test_get_tensorstore_spec_ocdbt_grpc(self, backend, target_driver):
511+
if backend:
512+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
513+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
514+
spec = serialization.get_tensorstore_spec(
515+
'gs://my/ckpt/dir/path', ocdbt=True
516+
)
517+
self.assertEqual(
518+
spec['kvstore']['base'],
519+
{
520+
'driver': target_driver,
521+
'bucket': 'my',
522+
'path': 'ckpt/dir',
523+
},
524+
)
525+
526+
@parameterized.named_parameters(
527+
dict(testcase_name='none', backend=None, target_driver='gcs'),
528+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
529+
dict(
530+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
531+
),
532+
)
533+
def test_get_tensorstore_spec_grpc(self, backend, target_driver):
534+
if backend:
535+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
536+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
537+
spec = serialization.get_tensorstore_spec(
538+
'gs://my/ckpt/dir/path', ocdbt=False
539+
)
540+
self.assertEqual(
541+
spec['kvstore'],
542+
{
543+
'driver': target_driver,
544+
'bucket': 'my',
545+
'path': 'ckpt/dir/path',
546+
},
547+
)
548+
496549
def test_deserialization_with_int4(self):
497550
dtype = jnp.int4
498551
shape = (8, 2)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
ZARR_VER3 = 'zarr3'
4545

4646
_GCS_PATH_RE = r'^gs://([^/]*)/(.*)$'
47+
_GCS_BUCKET_RE = r'^gs://([^/]*)$'
4748

4849
# Even if the data is equal to the fill value, we still want to write it
4950
# to the checkpoint. This results in unnecessary writes in some edge
@@ -112,15 +113,35 @@ def get_ts_context(
112113

113114

114115
def _get_kvstore_for_gcs(ckpt_path: str) -> JsonSpec:
116+
"""Constructs a TensorStore kvstore spec for a GCS path.
117+
118+
Args:
119+
ckpt_path: A GCS path of the form gs://<bucket>/<path>, or gs://<bucket> when the path is only a bucket name.
120+
121+
Returns:
122+
A dictionary containing the TensorStore kvstore spec.
123+
124+
Raises:
125+
ValueError: if ckpt_path is not a valid GCS path.
126+
"""
115127
m = re.fullmatch(_GCS_PATH_RE, ckpt_path, re.DOTALL)
116128
if m is None:
117-
raise ValueError(
118-
'The ckpt_path should contain the bucket name and the '
119-
f'file path inside the bucket. Got: {ckpt_path}'
120-
)
129+
# The path might only be a bucket name.
130+
m = re.fullmatch(_GCS_BUCKET_RE, ckpt_path, re.DOTALL)
131+
if m is None:
132+
raise ValueError(
133+
'The ckpt_path should contain the bucket name and the '
134+
f'file path inside the bucket. Got: {ckpt_path}'
135+
)
121136
gcs_bucket = m.group(1)
122-
path_without_bucket = m.group(2)
123-
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
137+
path_without_bucket = m.group(2) if m.lastindex == 2 else None
138+
# TODO(mirvine): Switch to gcs_grpc by default.
139+
# gcs_grpc performs roughly twice as fast as gcs backend.
140+
gcs_backend = os.environ.get('TENSORSTORE_GCS_BACKEND', 'gcs')
141+
spec = {'driver': gcs_backend, 'bucket': gcs_bucket}
142+
if path_without_bucket:
143+
spec['path'] = path_without_bucket
144+
return spec
124145

125146

126147
def build_kvstore_tspec(
@@ -165,7 +186,7 @@ def build_kvstore_tspec(
165186
directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
166187
)
167188
base_driver_spec = (
168-
directory
189+
_get_kvstore_for_gcs(str(directory))
169190
if is_gcs_path
170191
else {'driver': default_driver, 'path': str(directory)}
171192
)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@ def test_ocdbt_kvstore(
203203
dict(
204204
testcase_name='regular_path',
205205
directory='gs://gcs_bucket/object_path',
206-
expected_directory=None,
206+
expected_directory='object_path',
207207
),
208208
dict(
209209
testcase_name='path_with_single_slash',
210210
directory='gs:/gcs_bucket/object_path',
211-
expected_directory='gs://gcs_bucket/object_path',
211+
expected_directory='object_path',
212212
),
213213
)
214214
def test_ocdbt_kvstore_with_gcs_path(
@@ -228,10 +228,60 @@ def test_ocdbt_kvstore_with_gcs_path(
228228
self.assertEqual(kvstore_tspec['driver'], 'ocdbt')
229229
self.assertEqual(
230230
kvstore_tspec['base'],
231-
os.path.join(expected_directory or directory, 'ocdbt.process_0'),
231+
{
232+
'driver': 'gcs',
233+
'bucket': 'gcs_bucket',
234+
'path': os.path.join(
235+
expected_directory or directory, 'ocdbt.process_0'
236+
),
237+
},
232238
)
233239
self.assertEqual(kvstore_tspec['path'], self.param_name)
234240

241+
@parameterized.named_parameters(
242+
dict(testcase_name='none', backend=None, target_driver='gcs'),
243+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
244+
dict(
245+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
246+
),
247+
)
248+
def test_get_tensorstore_spec_ocdbt_grpc(self, backend, target_driver):
249+
if backend:
250+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
251+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
252+
spec = ts_utils.build_kvstore_tspec('gs://my/ckpt/dir/path', use_ocdbt=True)
253+
self.assertEqual(
254+
spec['base'],
255+
{
256+
'driver': target_driver,
257+
'bucket': 'my',
258+
'path': 'ckpt/dir/path',
259+
},
260+
)
261+
262+
@parameterized.named_parameters(
263+
dict(testcase_name='none', backend=None, target_driver='gcs'),
264+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
265+
dict(
266+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
267+
),
268+
)
269+
def test_get_tensorstore_spec_grpc(self, backend, target_driver):
270+
if backend:
271+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
272+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
273+
spec = ts_utils.build_kvstore_tspec(
274+
'gs://my/ckpt/dir/path', use_ocdbt=False
275+
)
276+
self.assertEqual(
277+
spec,
278+
{
279+
'driver': target_driver,
280+
'bucket': 'my',
281+
'path': 'ckpt/dir/path',
282+
},
283+
)
284+
235285
@parameterized.product(use_zarr3=(True, False))
236286
def test_ocdbt_kvstore_default_target_data_file_size(self, use_zarr3: bool):
237287
tspec = self.array_write_spec_constructor(

0 commit comments

Comments
 (0)