Skip to content

Commit

Permalink
Move NoopFuture, ChainedFuture and ThreadRaisingException from …
Browse files Browse the repository at this point in the history
…`orbax.checkpoint.futures` to `orbax.checkpoint._src.futures`

PiperOrigin-RevId: 718244371
  • Loading branch information
mridul-sahu authored and Orbax Authors committed Jan 22, 2025
1 parent f7bbe80 commit fbc1cc0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 63 deletions.
59 changes: 58 additions & 1 deletion checkpoint/orbax/checkpoint/_src/futures/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""Futures that can be used for signaling for synchronization."""

import threading
from typing import Any, Coroutine, Optional
import time
from typing import Any, Callable, Coroutine, Optional, Sequence

from absl import logging
from orbax.checkpoint._src import asyncio_utils
Expand Down Expand Up @@ -48,6 +49,62 @@ def result(self, timeout: Optional[int] = None) -> Any:
...


class NoopFuture:

def result(self, timeout: Optional[int] = None) -> Any:
del timeout
return None


class ChainedFuture:
"""A future representing a sequence of multiple futures."""

def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
self._futures = futures
self._cb = cb

def result(self, timeout: Optional[int] = None) -> Any:
"""Waits for all futures to complete."""
n = len(self._futures)
start = time.time()
time_remaining = timeout
for k, f in enumerate(self._futures):
f.result(timeout=time_remaining)
if time_remaining is not None:
time_elapsed = time.time() - start
time_remaining -= time_elapsed
if time_remaining <= 0:
raise TimeoutError(
'ChainedFuture completed {:d}/{:d} futures but timed out after'
' {:.2f} seconds.'.format(k, n, time_elapsed)
)
time_elapsed = time.time() - start
logging.info(
'ChainedFuture completed %d/%d futures in %.2f seconds.',
n,
n,
time_elapsed,
)
self._cb()


class ThreadRaisingException(threading.Thread):
"""Thread that raises an exception if it encounters an error."""

_exception: Optional[Exception] = None

def run(self):
try:
super().run()
except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e

def join(self, timeout=None):
super().join(timeout=timeout)
if self._exception is not None:
raise self._exception


class _SignalingThread(threading.Thread):
"""Thread that raises an exception if it encounters an error.
Expand Down
67 changes: 5 additions & 62 deletions checkpoint/orbax/checkpoint/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,9 @@

"""Orbax Future class used for duck typing."""

import threading
import time
from typing import Any, Callable, Optional, Sequence
from absl import logging
from orbax.checkpoint._src.futures import future
# pylint: disable=g-importing-member, unused-import


Future = future.Future


class NoopFuture:

def result(self, timeout: Optional[int] = None) -> Any:
del timeout
return None


class ThreadRaisingException(threading.Thread):
"""Thread that raises an exception if it encounters an error."""
_exception: Optional[Exception] = None

def run(self):
try:
super().run()
except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e

def join(self, timeout=None):
super().join(timeout=timeout)
if self._exception is not None:
raise self._exception


class ChainedFuture:
"""A future representing a sequence of multiple futures."""

def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
self._futures = futures
self._cb = cb

def result(self, timeout: Optional[int] = None) -> Any:
"""Waits for all futures to complete."""
n = len(self._futures)
start = time.time()
time_remaining = timeout
for k, f in enumerate(self._futures):
f.result(timeout=time_remaining)
if time_remaining is not None:
time_elapsed = time.time() - start
time_remaining -= time_elapsed
if time_remaining <= 0:
raise TimeoutError(
'ChainedFuture completed {:d}/{:d} futures but timed out after'
' {:.2f} seconds.'.format(k, n, time_elapsed)
)
time_elapsed = time.time() - start
logging.info(
'ChainedFuture completed %d/%d futures in %.2f seconds.',
n,
n,
time_elapsed,
)
self._cb()
from orbax.checkpoint._src.futures.future import ChainedFuture
from orbax.checkpoint._src.futures.future import Future
from orbax.checkpoint._src.futures.future import NoopFuture
from orbax.checkpoint._src.futures.future import ThreadRaisingException

0 comments on commit fbc1cc0

Please sign in to comment.