Skip to content

Commit 84e3246

Browse files
authored
Issue 4 recurse late-bound values and allow functions (#5)
* restore `__` as an alias for late * recurse over late bound values and allow functions fixes #4
1 parent 55ba3ba commit 84e3246

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

Diff for: late/__init__.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,40 @@ class _LateBound(NamedTuple):
1515
actual: Any
1616

1717

18-
def __(o: _T | Iterator[_V]) -> _T | _V:
18+
def late(o: _T | Iterator[_V]) -> _T | _V:
1919
if isinstance(o, int | float | str | bool | bytes | bytearray | frozenset):
2020
return o # type: ignore
21+
22+
actual: Any = None
23+
if isinstance(o, list):
24+
actual = [late(value) for value in o]
25+
elif isinstance(o, dict):
26+
actual = {name: late(value) for name, value in o.items()}
27+
elif isinstance(o, set):
28+
actual = {late(value) for value in o}
2129
else:
22-
return _LateBound(actual=o) # type: ignore
30+
actual = o
31+
return _LateBound(actual=actual) # type: ignore
32+
33+
34+
__ = late
2335

2436

2537
def _lateargs(func: Callable, **kwargs) -> dict[str, Any]:
2638

2739
def resolve_default(value):
2840
if inspect.isgenerator(value):
2941
return next(value)
42+
if inspect.isfunction(value):
43+
return value()
44+
if isinstance(value, _LateBound):
45+
return resolve_default(value.actual)
46+
if isinstance(value, list):
47+
return [resolve_default(x) for x in value]
48+
if isinstance(value, dict):
49+
return {name: resolve_default(x) for name, x in value.items()}
50+
if isinstance(value, set):
51+
return {resolve_default(x) for x in value}
3052
return copy.copy(value)
3153

3254
lateargs = {

Diff for: test/complex_test.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import inspect
2+
from typing import Any
3+
4+
from late import latebinding, __, _LateBound
5+
6+
7+
def test_recursion():
8+
@latebinding
9+
def f(x: list[list[Any]] = __([[]])) -> list[list[Any]]:
10+
x[0].append(1)
11+
return x
12+
13+
assert f() == [[1]]
14+
assert f() == [[1]]
15+
16+
17+
def test_function():
18+
t = 0
19+
20+
def a() -> int:
21+
nonlocal t
22+
t += 1
23+
return t
24+
25+
@latebinding
26+
def f(x: int = __(a)) -> int: # type: ignore
27+
return 2 * x
28+
29+
assert f() == 2
30+
assert f() == 4

0 commit comments

Comments
 (0)