@@ -399,15 +399,16 @@ def inner(*args, **kwds):
399
399
400
400
return decorator
401
401
402
- def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
402
+
403
+ def _eval_type (t , globalns , localns , type_params , * , recursive_guard = frozenset ()):
403
404
"""Evaluate all forward references in the given type t.
404
405
405
406
For use of globalns and localns see the docstring for get_type_hints().
406
407
recursive_guard is used to prevent infinite recursion with a recursive
407
408
ForwardRef.
408
409
"""
409
410
if isinstance (t , ForwardRef ):
410
- return t ._evaluate (globalns , localns , recursive_guard )
411
+ return t ._evaluate (globalns , localns , type_params , recursive_guard = recursive_guard )
411
412
if isinstance (t , (_GenericAlias , GenericAlias , types .UnionType )):
412
413
if isinstance (t , GenericAlias ):
413
414
args = tuple (
@@ -421,7 +422,13 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
421
422
t = t .__origin__ [args ]
422
423
if is_unpacked :
423
424
t = Unpack [t ]
424
- ev_args = tuple (_eval_type (a , globalns , localns , recursive_guard ) for a in t .__args__ )
425
+
426
+ ev_args = tuple (
427
+ _eval_type (
428
+ a , globalns , localns , type_params , recursive_guard = recursive_guard
429
+ )
430
+ for a in t .__args__
431
+ )
425
432
if ev_args == t .__args__ :
426
433
return t
427
434
if isinstance (t , GenericAlias ):
@@ -974,7 +981,7 @@ def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
974
981
self .__forward_is_class__ = is_class
975
982
self .__forward_module__ = module
976
983
977
- def _evaluate (self , globalns , localns , recursive_guard ):
984
+ def _evaluate (self , globalns , localns , type_params , * , recursive_guard ):
978
985
if self .__forward_arg__ in recursive_guard :
979
986
return self
980
987
if not self .__forward_evaluated__ or localns is not globalns :
@@ -988,14 +995,25 @@ def _evaluate(self, globalns, localns, recursive_guard):
988
995
globalns = getattr (
989
996
sys .modules .get (self .__forward_module__ , None ), '__dict__' , globalns
990
997
)
998
+ if type_params :
999
+ # "Inject" type parameters into the local namespace
1000
+ # (unless they are shadowed by assignments *in* the local namespace),
1001
+ # as a way of emulating annotation scopes when calling `eval()`
1002
+ locals_to_pass = {param .__name__ : param for param in type_params } | localns
1003
+ else :
1004
+ locals_to_pass = localns
991
1005
type_ = _type_check (
992
- eval (self .__forward_code__ , globalns , localns ),
1006
+ eval (self .__forward_code__ , globalns , locals_to_pass ),
993
1007
"Forward references must evaluate to types." ,
994
1008
is_argument = self .__forward_is_argument__ ,
995
1009
allow_special_forms = self .__forward_is_class__ ,
996
1010
)
997
1011
self .__forward_value__ = _eval_type (
998
- type_ , globalns , localns , recursive_guard | {self .__forward_arg__ }
1012
+ type_ ,
1013
+ globalns ,
1014
+ localns ,
1015
+ type_params ,
1016
+ recursive_guard = (recursive_guard | {self .__forward_arg__ }),
999
1017
)
1000
1018
self .__forward_evaluated__ = True
1001
1019
return self .__forward_value__
@@ -2334,7 +2352,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2334
2352
value = type (None )
2335
2353
if isinstance (value , str ):
2336
2354
value = ForwardRef (value , is_argument = False , is_class = True )
2337
- value = _eval_type (value , base_globals , base_locals )
2355
+ value = _eval_type (value , base_globals , base_locals , base . __type_params__ )
2338
2356
hints [name ] = value
2339
2357
return hints if include_extras else {k : _strip_annotations (t ) for k , t in hints .items ()}
2340
2358
@@ -2360,6 +2378,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2360
2378
raise TypeError ('{!r} is not a module, class, method, '
2361
2379
'or function.' .format (obj ))
2362
2380
hints = dict (hints )
2381
+ type_params = getattr (obj , "__type_params__" , ())
2363
2382
for name , value in hints .items ():
2364
2383
if value is None :
2365
2384
value = type (None )
@@ -2371,7 +2390,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2371
2390
is_argument = not isinstance (obj , types .ModuleType ),
2372
2391
is_class = False ,
2373
2392
)
2374
- hints [name ] = _eval_type (value , globalns , localns )
2393
+ hints [name ] = _eval_type (value , globalns , localns , type_params )
2375
2394
return hints if include_extras else {k : _strip_annotations (t ) for k , t in hints .items ()}
2376
2395
2377
2396
0 commit comments