Skip to content

Commit 49a5649

Browse files
authored
Fix inheritance edge cases for type parameter resolution (#230)
* Fix issue with generics when TypeVars are reused * Refine * Fix type param resolution for superclass __init__ / __new__ * pydantic test * attrs * Tests * Python 3.8 hack * Minor update for coverage * ruff * mypy * upgrade ruff * remove redundant annotations in tests * no cover for py38 hack * run test gen
1 parent 6761eed commit 49a5649

14 files changed

+465
-52
lines changed

src/tyro/_fields.py

+73-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dataclasses
88
import functools
99
import inspect
10+
import sys
1011
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1112

1213
import docstring_parser
@@ -74,7 +75,7 @@ def make(
7475
helptext: Optional[str],
7576
call_argname_override: Optional[Any] = None,
7677
):
77-
# Resolve generics.
78+
# Resolve type parameters.
7879
typ = _resolver.TypeParamResolver.concretize_type_params(typ)
7980

8081
# Narrow types.
@@ -265,14 +266,17 @@ def _field_list_from_function(
265266

266267
# Unwrap functools.wraps and functools.partial.
267268
done = False
269+
functools_marker = False
268270
while not done:
269271
done = True
270272
if hasattr(f, "__wrapped__"):
271273
f = f.__wrapped__ # type: ignore
272274
done = False
275+
functools_marker = True
273276
if isinstance(f, functools.partial):
274277
f = f.func
275278
done = False
279+
functools_marker = True
276280

277281
# Check for abstract classes.
278282
if inspect.isabstract(f):
@@ -282,11 +286,71 @@ def _field_list_from_function(
282286
# signature. But the docstrings may still be in the class signature itself.
283287
f_before_init_unwrap = f
284288

289+
hints = None
290+
285291
if inspect.isclass(f):
292+
signature_func = None
286293
if hasattr(f, "__init__") and f.__init__ is not object.__init__:
287-
f = f.__init__ # type: ignore
294+
signature_func = "__init__"
288295
elif hasattr(f, "__new__") and f.__new__ is not object.__new__:
289-
f = f.__new__
296+
signature_func = "__new__"
297+
298+
if signature_func is not None:
299+
# Get the __init__ / __new__ method from the class, as well as the
300+
# class that contains it.
301+
#
302+
# We call this the "signature function", because it's the function
303+
# that we use to instantiate the class.
304+
orig_cls = f
305+
base_cls_with_signature = None
306+
for base_cls_with_signature in inspect.getmro(f):
307+
if signature_func in base_cls_with_signature.__dict__:
308+
f = getattr(base_cls_with_signature, signature_func)
309+
break
310+
assert base_cls_with_signature is not None
311+
assert f is not orig_cls
312+
313+
# For older versions of Python, the signature returned above (when
314+
# passed through generics base classes) will sometimes be (*args,
315+
# **kwargs).
316+
#
317+
# This is a hack. We can remove it if we deprecate Python 3.8 support.
318+
if sys.version_info < (3, 9) and not functools_marker: # pragma: no cover
319+
params = list(inspect.signature(f).parameters.values())[1:]
320+
321+
# Get hints for the signature function by recursing through the
322+
# inheritance tree. This is needed to correctly resolve type
323+
# parameters, which can be set anywhere between the input class and
324+
# the class where the __init__ or __new__ method is defined.
325+
def get_hints_for_signature_func(cls):
326+
typevar_context = _resolver.TypeParamResolver.get_assignment_context(
327+
cls
328+
)
329+
cls = typevar_context.origin_type
330+
with typevar_context:
331+
if cls is base_cls_with_signature:
332+
return _resolver.get_type_hints_resolve_type_params(
333+
f, include_extras=True
334+
)
335+
for base_cls in (
336+
cls.__orig_bases__
337+
if hasattr(cls, "__orig_bases__")
338+
else cls.__bases__
339+
):
340+
if not issubclass(
341+
_resolver.unwrap_origin_strip_extras(base_cls),
342+
base_cls_with_signature,
343+
):
344+
continue
345+
return get_hints_for_signature_func(
346+
_resolver.TypeParamResolver.concretize_type_params(base_cls)
347+
)
348+
349+
assert False, (
350+
"We couldn't find the base class. This seems like a bug in tyro."
351+
)
352+
353+
hints = get_hints_for_signature_func(orig_cls)
290354

291355
# Get type annotations, docstrings.
292356
docstring = inspect.getdoc(f)
@@ -296,11 +360,13 @@ def _field_list_from_function(
296360
docstring_from_arg_name[param_doc.arg_name] = param_doc.description
297361
del docstring
298362

363+
# Get hints if we haven't done it already.
299364
# This will throw a type error for torch.device, typing.Dict, etc.
300-
try:
301-
hints = _resolver.get_type_hints_with_backported_syntax(f, include_extras=True)
302-
except TypeError:
303-
return UnsupportedStructTypeMessage(f"Could not get hints for {f}!")
365+
if hints is None:
366+
try:
367+
hints = _resolver.get_type_hints_resolve_type_params(f, include_extras=True)
368+
except TypeError:
369+
return UnsupportedStructTypeMessage(f"Could not get hints for {f}!")
304370

305371
field_list = []
306372
for param in params:

src/tyro/_resolver.py

+49-21
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Callable,
1515
ClassVar,
1616
Dict,
17+
Generic,
1718
List,
1819
Sequence,
1920
Set,
@@ -76,14 +77,14 @@ def is_dataclass(cls: Union[TypeForm, Callable]) -> bool:
7677
return dataclasses.is_dataclass(unwrap_origin_strip_extras(cls)) # type: ignore
7778

7879

79-
@_unsafe_cache.unsafe_cache(maxsize=1024)
80+
# @_unsafe_cache.unsafe_cache(maxsize=1024)
8081
def resolved_fields(cls: TypeForm) -> List[dataclasses.Field]:
8182
"""Similar to dataclasses.fields(), but includes dataclasses.InitVar types and
8283
resolves forward references."""
8384

8485
assert dataclasses.is_dataclass(cls)
8586
fields = []
86-
annotations = get_type_hints_with_backported_syntax(
87+
annotations = get_type_hints_resolve_type_params(
8788
cast(Callable, cls), include_extras=True
8889
)
8990
for field in getattr(cls, "__dataclass_fields__").values():
@@ -566,24 +567,6 @@ def resolve_generic_types(
566567
assert len(typevars) == len(typevar_values)
567568
typ = origin_cls
568569
type_from_typevar.update(dict(zip(typevars, typevar_values)))
569-
elif (
570-
# Apply some heuristics for generic types. Should revisit this.
571-
hasattr(typ, "__parameters__") and hasattr(typ.__parameters__, "__len__") # type: ignore
572-
):
573-
typevars = typ.__parameters__ # type: ignore
574-
typevar_values = tuple(type_from_typevar_constraints(x) for x in typevars)
575-
assert len(typevars) == len(typevar_values)
576-
type_from_typevar.update(dict(zip(typevars, typevar_values)))
577-
578-
if hasattr(typ, "__orig_bases__"):
579-
bases = getattr(typ, "__orig_bases__")
580-
for base in bases:
581-
origin_base = unwrap_origin_strip_extras(base)
582-
if origin_base is base or not hasattr(origin_base, "__parameters__"):
583-
continue
584-
typevars = origin_base.__parameters__
585-
typevar_values = get_args(base)
586-
type_from_typevar.update(dict(zip(typevars, typevar_values)))
587570

588571
if len(annotations) == 0:
589572
return typ, type_from_typevar
@@ -594,11 +577,56 @@ def resolve_generic_types(
594577
)
595578

596579

597-
def get_type_hints_with_backported_syntax(
580+
def get_type_hints_resolve_type_params(
581+
obj: Callable[..., Any], include_extras: bool = False
582+
) -> Dict[str, Any]:
583+
"""Variant of `typing.get_type_hints()` that resolves type parameters."""
584+
if not inspect.isclass(obj):
585+
return {
586+
k: TypeParamResolver.concretize_type_params(v)
587+
for k, v in _get_type_hints_backported_syntax(
588+
obj, include_extras=include_extras
589+
).items()
590+
}
591+
592+
typevar_context = TypeParamResolver.get_assignment_context(obj)
593+
obj = typevar_context.origin_type
594+
with typevar_context:
595+
# Only include type hints that are explicitly defined in this class.
596+
# The follow loop will handle superclasses.
597+
out = {
598+
x: TypeParamResolver.concretize_type_params(t)
599+
for x, t in _get_type_hints_backported_syntax(
600+
obj, include_extras=include_extras
601+
).items()
602+
if x in obj.__annotations__
603+
}
604+
605+
# We need to recurse into base classes in order to correctly resolve superclass parameters.
606+
for base in obj.__orig_bases__ if hasattr(obj, "__orig_bases__") else obj.__bases__: # type: ignore
607+
base_typevar_context = TypeParamResolver.get_assignment_context(base)
608+
if get_origin(base_typevar_context.origin_type) is Generic:
609+
continue
610+
with base_typevar_context:
611+
base_hints = get_type_hints_resolve_type_params(
612+
base_typevar_context.origin_type, include_extras=include_extras
613+
)
614+
out.update(
615+
{
616+
x: TypeParamResolver.concretize_type_params(t)
617+
for x, t in base_hints.items()
618+
}
619+
)
620+
621+
return out
622+
623+
624+
def _get_type_hints_backported_syntax(
598625
obj: Callable[..., Any], include_extras: bool = False
599626
) -> Dict[str, Any]:
600627
"""Same as `typing.get_type_hints()`, but supports new union syntax (X | Y)
601628
and generics (list[str]) in older versions of Python."""
629+
602630
try:
603631
out = get_type_hints(obj, include_extras=include_extras)
604632

src/tyro/constructors/_registry.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,12 @@ def get_struct_spec(cls, type_info: StructTypeInfo) -> StructConstructorSpec | N
173173
f"Invalid default instance for type {type_info.type}: {type_info.default}"
174174
)
175175

176-
for registry in cls._active_registries[::-1]:
177-
for spec_factory in registry._struct_rules[::-1]:
178-
maybe_spec = spec_factory(type_info)
179-
if maybe_spec is not None:
180-
return maybe_spec
176+
with type_info._typevar_context:
177+
for registry in cls._active_registries[::-1]:
178+
for spec_factory in registry._struct_rules[::-1]:
179+
maybe_spec = spec_factory(type_info)
180+
if maybe_spec is not None:
181+
return maybe_spec
181182

182183
return None
183184

src/tyro/constructors/_struct_spec.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def typeddict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
214214
total = getattr(cls, "__total__", True)
215215
assert isinstance(total, bool)
216216
assert not valid_default_instance or isinstance(info.default, dict)
217-
for name, typ in _resolver.get_type_hints_with_backported_syntax(
217+
for name, typ in _resolver.get_type_hints_resolve_type_params(
218218
cls, include_extras=True
219219
).items():
220220
typ_origin = get_origin(typ)
@@ -277,7 +277,13 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
277277
return None
278278

279279
# Resolve forward references in-place, if any exist.
280-
attr.resolve_types(info.type)
280+
# attr.resolve_types(info.type)
281+
282+
# We'll use our own type resolution system instead of attr's. This is
283+
# primarily to improve generics support.
284+
our_hints = _resolver.get_type_hints_resolve_type_params(
285+
info.type, include_extras=True
286+
)
281287

282288
# Handle attr classes.
283289
field_list = []
@@ -301,7 +307,7 @@ def attrs_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
301307
field_list.append(
302308
StructFieldSpec(
303309
name=name,
304-
type=attr_field.type,
310+
type=our_hints[name],
305311
default=default,
306312
helptext=_docstrings.get_field_docstring(info.type, name),
307313
)
@@ -353,7 +359,7 @@ def namedtuple_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
353359
field_list = []
354360
field_defaults = getattr(info.type, "_field_defaults", {})
355361

356-
for name, typ in _resolver.get_type_hints_with_backported_syntax(
362+
for name, typ in _resolver.get_type_hints_resolve_type_params(
357363
info.type, include_extras=True
358364
).items():
359365
default = field_defaults.get(name, MISSING_NONPROP)
@@ -534,7 +540,7 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
534540
):
535541
# Pydantic 1.xx
536542
cls_cast = info.type
537-
hints = _resolver.get_type_hints_with_backported_syntax(
543+
hints = _resolver.get_type_hints_resolve_type_params(
538544
info.type, include_extras=True
539545
)
540546
for pd1_field in cast(Dict[str, Any], cls_cast.__fields__).values():

tests/test_attrs.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import contextlib
44
import io
55
import pathlib
6-
from typing import cast
6+
from typing import Generic, TypeVar, cast
77

88
import attr
99
import pytest
1010
from attrs import define, field
11+
from helptext_utils import get_helptext_with_checks
1112

1213
import tyro
1314
import tyro._strings
@@ -129,3 +130,24 @@ class ManyTypesB:
129130
args=["--i", "5"],
130131
default=ManyTypesB(i=5, s="5", f=2.0),
131132
) == ManyTypesB(i=5, s="5", f=2.0)
133+
134+
135+
T = TypeVar("T")
136+
137+
138+
def test_attrs_inheritance_with_same_typevar() -> None:
139+
@attr.s
140+
class A(Generic[T]):
141+
x: T = attr.ib()
142+
143+
@attr.s
144+
class B(A[int], Generic[T]):
145+
y: T = attr.ib()
146+
147+
assert "INT" in get_helptext_with_checks(B[int])
148+
assert "STR" not in get_helptext_with_checks(B[int])
149+
assert "STR" in get_helptext_with_checks(B[str])
150+
assert "INT" in get_helptext_with_checks(B[str])
151+
152+
assert tyro.cli(B[str], args=["--x", "1", "--y", "2"]) == B(x=1, y="2")
153+
assert tyro.cli(B[int], args=["--x", "1", "--y", "2"]) == B(x=1, y=2)

0 commit comments

Comments
 (0)