Skip to content

Commit 6fe5e02

Browse files
committed
Refactor options parsing to validate complex data structures
1 parent 4a764e2 commit 6fe5e02

File tree

7 files changed

+370
-155
lines changed

7 files changed

+370
-155
lines changed

Diff for: poethepoet/config/partition.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
Optional,
77
Sequence,
88
Type,
9+
TypedDict,
910
Union,
1011
)
1112

1213
from ..exceptions import ConfigValidationError
1314
from ..options import NoValue, PoeOptions
15+
from .primitives import EmptyDict, EnvDefault
1416

1517
KNOWN_SHELL_INTERPRETERS = (
1618
"posix",
@@ -24,6 +26,14 @@
2426
)
2527

2628

29+
class IncludeItem(TypedDict):
30+
path: str
31+
cwd: str
32+
33+
34+
IncludeItem.__optional_keys__ = frozenset({"cwd"})
35+
36+
2737
class ConfigPartition:
2838
options: PoeOptions
2939
full_config: Mapping[str, Any]
@@ -74,9 +84,6 @@ def get(self, key: str, default: Any = NoValue):
7484
return self.options.get(key, default)
7585

7686

77-
EmptyDict: Mapping = MappingProxyType({})
78-
79-
8087
class ProjectConfig(ConfigPartition):
8188
is_primary = True
8289

@@ -88,10 +95,10 @@ class ConfigOptions(PoeOptions):
8895
default_task_type: str = "cmd"
8996
default_array_task_type: str = "sequence"
9097
default_array_item_task_type: str = "ref"
91-
env: Mapping[str, str] = EmptyDict
98+
env: Mapping[str, Union[str, EnvDefault]] = EmptyDict
9299
envfile: Union[str, Sequence[str]] = tuple()
93100
executor: Mapping[str, str] = MappingProxyType({"type": "auto"})
94-
include: Sequence[str] = tuple()
101+
include: Union[str, Sequence[str], Sequence[IncludeItem]] = tuple()
95102
poetry_command: str = "poe"
96103
poetry_hooks: Mapping[str, str] = EmptyDict
97104
shell_interpreter: Union[str, Sequence[str]] = "posix"

Diff for: poethepoet/config/primitives.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from types import MappingProxyType
2+
from typing import Mapping, TypedDict
3+
4+
EmptyDict: Mapping = MappingProxyType({})
5+
6+
7+
class EnvDefault(TypedDict):
8+
default: str

Diff for: poethepoet/options.py renamed to poethepoet/options/__init__.py

+24-107
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
1-
import collections
1+
from __future__ import annotations
2+
23
from keyword import iskeyword
3-
from typing import (
4-
Any,
5-
Dict,
6-
List,
7-
Literal,
8-
Mapping,
9-
MutableMapping,
10-
Optional,
11-
Sequence,
12-
Tuple,
13-
Type,
14-
Union,
15-
get_args,
16-
get_origin,
17-
)
18-
19-
from .exceptions import ConfigValidationError
4+
from typing import Any, Mapping, Sequence, get_type_hints
5+
6+
from ..exceptions import ConfigValidationError
7+
from .annotations import TypeAnnotation
208

219
NoValue = object()
2210

@@ -26,7 +14,7 @@ class PoeOptions:
2614
A special kind of config object that parses options ...
2715
"""
2816

29-
__annotations: Dict[str, Type]
17+
__annotations: dict[str, TypeAnnotation]
3018

3119
def __init__(self, **options: Any):
3220
for key in self.get_fields():
@@ -61,13 +49,13 @@ def __getattr__(self, name: str):
6149
@classmethod
6250
def parse(
6351
cls,
64-
source: Union[Mapping[str, Any], list],
52+
source: Mapping[str, Any] | list,
6553
strict: bool = True,
6654
extra_keys: Sequence[str] = tuple(),
6755
):
6856
config_keys = {
69-
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: vtype
70-
for key, vtype in cls.get_fields().items()
57+
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: type_
58+
for key, type_ in cls.get_fields().items()
7159
}
7260
if strict:
7361
for index, item in enumerate(cls.normalize(source, strict)):
@@ -110,29 +98,8 @@ def _parse_value(
11098
return value_type.parse(value, strict=strict)
11199

112100
if strict:
113-
expected_type: Union[Type, Tuple[Type, ...]] = cls._type_of(value_type)
114-
if not isinstance(value, expected_type):
115-
# Try format expected_type nicely in the error message
116-
if not isinstance(expected_type, tuple):
117-
expected_type = (expected_type,)
118-
formatted_type = " | ".join(
119-
type_.__name__ for type_ in expected_type if type_ is not type(None)
120-
)
121-
raise ConfigValidationError(
122-
f"Option {key!r} should have a value of type: {formatted_type}",
123-
index=index,
124-
)
125-
126-
annotation = cls.get_annotation(key)
127-
if get_origin(annotation) is Literal:
128-
allowed_values = get_args(annotation)
129-
if value not in allowed_values:
130-
raise ConfigValidationError(
131-
f"Option {key!r} must be one of {allowed_values!r}",
132-
index=index,
133-
)
134-
135-
# TODO: validate list/dict contents
101+
for error_msg in value_type.validate((key,), value):
102+
raise ConfigValidationError(error_msg, index=index)
136103

137104
return value
138105

@@ -171,43 +138,25 @@ def get(self, key: str, default: Any = NoValue) -> Any:
171138
if default is NoValue:
172139
# Fallback to getting getting the zero value for the type of this attribute
173140
# e.g. 0, False, empty list, empty dict, etc
174-
return self.__get_zero_value(key)
141+
annotation = self.get_fields().get(self._resolve_key(key))
142+
assert annotation
143+
return annotation.zero_value()
175144

176145
return default
177146

178-
def __get_zero_value(self, key: str):
179-
type_of_attr = self.type_of(key)
180-
if isinstance(type_of_attr, tuple):
181-
if type(None) in type_of_attr:
182-
# Optional types default to None
183-
return None
184-
type_of_attr = type_of_attr[0]
185-
assert type_of_attr
186-
return type_of_attr()
187-
188147
def __is_optional(self, key: str):
189-
# TODO: precache optional options keys?
190-
type_of_attr = self.type_of(key)
191-
if isinstance(type_of_attr, tuple):
192-
return type(None) in type_of_attr
193-
return False
148+
annotation = self.get_fields().get(self._resolve_key(key))
149+
assert annotation
150+
return annotation.is_optional
194151

195-
def update(self, options_dict: Dict[str, Any]):
152+
def update(self, options_dict: dict[str, Any]):
196153
new_options_dict = {}
197154
for key in self.get_fields().keys():
198155
if key in options_dict:
199156
new_options_dict[key] = options_dict[key]
200157
elif hasattr(self, key):
201158
new_options_dict[key] = getattr(self, key)
202159

203-
@classmethod
204-
def type_of(cls, key: str) -> Optional[Union[Type, Tuple[Type, ...]]]:
205-
return cls._type_of(cls.get_annotation(key))
206-
207-
@classmethod
208-
def get_annotation(cls, key: str) -> Optional[Type]:
209-
return cls.get_fields().get(cls._resolve_key(key))
210-
211160
@classmethod
212161
def _resolve_key(cls, key: str) -> str:
213162
"""
@@ -219,51 +168,19 @@ def _resolve_key(cls, key: str) -> str:
219168
return key
220169

221170
@classmethod
222-
def _type_of(cls, annotation: Any) -> Union[Type, Tuple[Type, ...]]:
223-
if get_origin(annotation) is Union:
224-
result: List[Type] = []
225-
for component in get_args(annotation):
226-
component_type = cls._type_of(component)
227-
if isinstance(component_type, tuple):
228-
result.extend(component_type)
229-
else:
230-
result.append(component_type)
231-
return tuple(result)
232-
233-
if get_origin(annotation) in (
234-
dict,
235-
Mapping,
236-
MutableMapping,
237-
collections.abc.Mapping,
238-
collections.abc.MutableMapping,
239-
):
240-
return dict
241-
242-
if get_origin(annotation) in (
243-
list,
244-
Sequence,
245-
collections.abc.Sequence,
246-
):
247-
return list
248-
249-
if get_origin(annotation) is Literal:
250-
return tuple({type(arg) for arg in get_args(annotation)})
251-
252-
return annotation
253-
254-
@classmethod
255-
def get_fields(cls) -> Dict[str, Any]:
171+
def get_fields(cls) -> dict[str, TypeAnnotation]:
256172
"""
257173
Recent python versions removed inheritance for __annotations__
258174
so we have to implement it explicitly
259175
"""
260176
if not hasattr(cls, "__annotations"):
261177
annotations = {}
262178
for base_cls in cls.__bases__:
263-
annotations.update(base_cls.__annotations__)
264-
annotations.update(cls.__annotations__)
179+
annotations.update(get_type_hints(base_cls))
180+
annotations.update(get_type_hints(cls))
181+
265182
cls.__annotations = {
266-
key: type_
183+
key: TypeAnnotation.parse(type_)
267184
for key, type_ in annotations.items()
268185
if not key.startswith("_")
269186
}

0 commit comments

Comments
 (0)