1
- import collections
1
+ from __future__ import annotations
2
+
2
3
from keyword import iskeyword
3
- from typing import (
4
- Any ,
5
- Dict ,
6
- List ,
7
- Literal ,
8
- Mapping ,
9
- MutableMapping ,
10
- Optional ,
11
- Sequence ,
12
- Tuple ,
13
- Type ,
14
- Union ,
15
- get_args ,
16
- get_origin ,
17
- )
18
-
19
- from .exceptions import ConfigValidationError
4
+ from typing import Any , Mapping , Sequence , get_type_hints
5
+
6
+ from ..exceptions import ConfigValidationError
7
+ from .annotations import TypeAnnotation
20
8
21
9
NoValue = object ()
22
10
@@ -26,7 +14,7 @@ class PoeOptions:
26
14
A special kind of config object that parses options ...
27
15
"""
28
16
29
- __annotations : Dict [str , Type ]
17
+ __annotations : dict [str , TypeAnnotation ]
30
18
31
19
def __init__ (self , ** options : Any ):
32
20
for key in self .get_fields ():
@@ -61,13 +49,13 @@ def __getattr__(self, name: str):
61
49
@classmethod
62
50
def parse (
63
51
cls ,
64
- source : Union [ Mapping [str , Any ], list ] ,
52
+ source : Mapping [str , Any ] | list ,
65
53
strict : bool = True ,
66
54
extra_keys : Sequence [str ] = tuple (),
67
55
):
68
56
config_keys = {
69
- key [:- 1 ] if key .endswith ("_" ) and iskeyword (key [:- 1 ]) else key : vtype
70
- for key , vtype in cls .get_fields ().items ()
57
+ key [:- 1 ] if key .endswith ("_" ) and iskeyword (key [:- 1 ]) else key : type_
58
+ for key , type_ in cls .get_fields ().items ()
71
59
}
72
60
if strict :
73
61
for index , item in enumerate (cls .normalize (source , strict )):
@@ -110,29 +98,8 @@ def _parse_value(
110
98
return value_type .parse (value , strict = strict )
111
99
112
100
if strict :
113
- expected_type : Union [Type , Tuple [Type , ...]] = cls ._type_of (value_type )
114
- if not isinstance (value , expected_type ):
115
- # Try format expected_type nicely in the error message
116
- if not isinstance (expected_type , tuple ):
117
- expected_type = (expected_type ,)
118
- formatted_type = " | " .join (
119
- type_ .__name__ for type_ in expected_type if type_ is not type (None )
120
- )
121
- raise ConfigValidationError (
122
- f"Option { key !r} should have a value of type: { formatted_type } " ,
123
- index = index ,
124
- )
125
-
126
- annotation = cls .get_annotation (key )
127
- if get_origin (annotation ) is Literal :
128
- allowed_values = get_args (annotation )
129
- if value not in allowed_values :
130
- raise ConfigValidationError (
131
- f"Option { key !r} must be one of { allowed_values !r} " ,
132
- index = index ,
133
- )
134
-
135
- # TODO: validate list/dict contents
101
+ for error_msg in value_type .validate ((key ,), value ):
102
+ raise ConfigValidationError (error_msg , index = index )
136
103
137
104
return value
138
105
@@ -171,43 +138,25 @@ def get(self, key: str, default: Any = NoValue) -> Any:
171
138
if default is NoValue :
172
139
# Fallback to getting getting the zero value for the type of this attribute
173
140
# e.g. 0, False, empty list, empty dict, etc
174
- return self .__get_zero_value (key )
141
+ annotation = self .get_fields ().get (self ._resolve_key (key ))
142
+ assert annotation
143
+ return annotation .zero_value ()
175
144
176
145
return default
177
146
178
- def __get_zero_value (self , key : str ):
179
- type_of_attr = self .type_of (key )
180
- if isinstance (type_of_attr , tuple ):
181
- if type (None ) in type_of_attr :
182
- # Optional types default to None
183
- return None
184
- type_of_attr = type_of_attr [0 ]
185
- assert type_of_attr
186
- return type_of_attr ()
187
-
188
147
def __is_optional (self , key : str ):
189
- # TODO: precache optional options keys?
190
- type_of_attr = self .type_of (key )
191
- if isinstance (type_of_attr , tuple ):
192
- return type (None ) in type_of_attr
193
- return False
148
+ annotation = self .get_fields ().get (self ._resolve_key (key ))
149
+ assert annotation
150
+ return annotation .is_optional
194
151
195
- def update (self , options_dict : Dict [str , Any ]):
152
+ def update (self , options_dict : dict [str , Any ]):
196
153
new_options_dict = {}
197
154
for key in self .get_fields ().keys ():
198
155
if key in options_dict :
199
156
new_options_dict [key ] = options_dict [key ]
200
157
elif hasattr (self , key ):
201
158
new_options_dict [key ] = getattr (self , key )
202
159
203
- @classmethod
204
- def type_of (cls , key : str ) -> Optional [Union [Type , Tuple [Type , ...]]]:
205
- return cls ._type_of (cls .get_annotation (key ))
206
-
207
- @classmethod
208
- def get_annotation (cls , key : str ) -> Optional [Type ]:
209
- return cls .get_fields ().get (cls ._resolve_key (key ))
210
-
211
160
@classmethod
212
161
def _resolve_key (cls , key : str ) -> str :
213
162
"""
@@ -219,51 +168,19 @@ def _resolve_key(cls, key: str) -> str:
219
168
return key
220
169
221
170
@classmethod
222
- def _type_of (cls , annotation : Any ) -> Union [Type , Tuple [Type , ...]]:
223
- if get_origin (annotation ) is Union :
224
- result : List [Type ] = []
225
- for component in get_args (annotation ):
226
- component_type = cls ._type_of (component )
227
- if isinstance (component_type , tuple ):
228
- result .extend (component_type )
229
- else :
230
- result .append (component_type )
231
- return tuple (result )
232
-
233
- if get_origin (annotation ) in (
234
- dict ,
235
- Mapping ,
236
- MutableMapping ,
237
- collections .abc .Mapping ,
238
- collections .abc .MutableMapping ,
239
- ):
240
- return dict
241
-
242
- if get_origin (annotation ) in (
243
- list ,
244
- Sequence ,
245
- collections .abc .Sequence ,
246
- ):
247
- return list
248
-
249
- if get_origin (annotation ) is Literal :
250
- return tuple ({type (arg ) for arg in get_args (annotation )})
251
-
252
- return annotation
253
-
254
- @classmethod
255
- def get_fields (cls ) -> Dict [str , Any ]:
171
+ def get_fields (cls ) -> dict [str , TypeAnnotation ]:
256
172
"""
257
173
Recent python versions removed inheritance for __annotations__
258
174
so we have to implement it explicitly
259
175
"""
260
176
if not hasattr (cls , "__annotations" ):
261
177
annotations = {}
262
178
for base_cls in cls .__bases__ :
263
- annotations .update (base_cls .__annotations__ )
264
- annotations .update (cls .__annotations__ )
179
+ annotations .update (get_type_hints (base_cls ))
180
+ annotations .update (get_type_hints (cls ))
181
+
265
182
cls .__annotations = {
266
- key : type_
183
+ key : TypeAnnotation . parse ( type_ )
267
184
for key , type_ in annotations .items ()
268
185
if not key .startswith ("_" )
269
186
}
0 commit comments