diff --git a/CHANGELOG.md b/CHANGELOG.md index 49f4d3ac..287dba3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#184](https://github.com/Lightning-AI/utilities/pull/184), [#185](https://github.com/Lightning-AI/utilities/pull/185)) +- Added `rank_zero_only(..., default=)` argument to return a default value on rank > 1 ([#187](https://github.com/Lightning-AI/utilities/pull/187)) + ### Changed diff --git a/src/lightning_utilities/core/rank_zero.py b/src/lightning_utilities/core/rank_zero.py index c0546be0..448013f5 100644 --- a/src/lightning_utilities/core/rank_zero.py +++ b/src/lightning_utilities/core/rank_zero.py @@ -9,7 +9,7 @@ from platform import python_version from typing import Any, Callable, Optional, TypeVar, Union -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, overload log = logging.getLogger(__name__) @@ -17,7 +17,17 @@ P = ParamSpec("P") +@overload def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: + ... + + +@overload +def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: + ... + + +def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]: """Wrap a function to call internal function only in rank zero. Function that can be used as a decorator to enable a function/method being called only on global rank 0. @@ -31,7 +41,7 @@ def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") if rank == 0: return fn(*args, **kwargs) - return None + return default return wrapped_fn diff --git a/tests/unittests/core/test_rank_zero.py b/tests/unittests/core/test_rank_zero.py index 523dbd43..b96b3d13 100644 --- a/tests/unittests/core/test_rank_zero.py +++ b/tests/unittests/core/test_rank_zero.py @@ -15,3 +15,14 @@ def test_rank_prefixed_message(rank): assert message == f"[rank: {rank}] bar" # reset del rank_zero_only.rank + + +def test_rank_zero_only_default(): + foo = lambda: "foo" + rank_zero_foo = rank_zero_only(foo, "not foo") + + rank_zero_only.rank = 0 + assert rank_zero_foo() == "foo" + + rank_zero_only.rank = 1 + assert rank_zero_foo() == "not foo"