Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4adca90
add generics type parameters to ImmutableMultiDict
adriangb Jan 29, 2022
7994f31
Merge branch 'master' into generic-immutable-multi-dict
adriangb Jan 30, 2022
bb122b4
Merge branch 'master' into generic-immutable-multi-dict
adriangb Jan 31, 2022
7a9c286
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 1, 2022
309112b
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 3, 2022
33fa54b
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 8, 2022
d18674e
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 8, 2022
8149559
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 11, 2022
73dfd45
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 11, 2022
0213422
Merge branch 'master' into generic-immutable-multi-dict
adriangb Feb 16, 2022
6812f66
Merge branch 'master' into generic-immutable-multi-dict
adriangb Mar 21, 2022
e542551
Merge branch 'master' into generic-immutable-multi-dict
adriangb Apr 22, 2022
64c5398
Merge branch 'master' into generic-immutable-multi-dict
adriangb Apr 24, 2022
5884be8
Merge branch 'master' into generic-immutable-multi-dict
adriangb May 6, 2022
f7eccee
Merge branch 'master' into generic-immutable-multi-dict
adriangb May 22, 2022
2dda0de
rename type variables
adriangb May 23, 2022
932df88
fmt
adriangb May 23, 2022
c4b34ee
Merge remote-tracking branch 'upstream/master' into generic-immutable…
adriangb May 24, 2022
9ed2781
accept an iterable in the constructor
adriangb May 24, 2022
0e92359
remove get
adriangb May 24, 2022
76127b5
add test proving removing get works
adriangb May 24, 2022
6966ba6
make mypy happy
adriangb May 24, 2022
8396e57
Update starlette/datastructures.py
adriangb May 24, 2022
5261a1f
Update starlette/datastructures.py
adriangb May 24, 2022
e7631e8
fix comments
adriangb May 24, 2022
75a5e43
fmt
adriangb May 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ class Address(typing.NamedTuple):
port: int


_KeyType = typing.TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)


class URL:
def __init__(
self,
Expand Down Expand Up @@ -238,32 +245,36 @@ def __str__(self) -> str:
return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
_dict: typing.Dict[_KeyType, _CovariantValueType]

def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"ImmutableMultiDict[_KeyType, _CovariantValueType]",
typing.Mapping[_KeyType, _CovariantValueType],
typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
],
**kwargs: typing.Any,
) -> None:
assert len(args) < 2, "Too many arguments."

value = args[0] if args else []
value: typing.Any = args[0] if args else []
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items()
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't know the returned value of the operator + when there's a sum of two ImmutableMultiDict. That's why we ignore this error here.

)

if not value:
_items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
elif hasattr(value, "multi_items"):
value = typing.cast(ImmutableMultiDict, value)
value = typing.cast(
ImmutableMultiDict[_KeyType, _CovariantValueType], value
)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping, value)
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
else:
value = typing.cast(
Expand All @@ -274,33 +285,28 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items

def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[_KeyType]:
return self._dict.keys()

def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[_CovariantValueType]:
return self._dict.values()

def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]:
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
if key in self._dict:
return self._dict[key]
return default

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
return self._dict[key]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[_KeyType]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -317,7 +323,7 @@ def __repr__(self) -> str:
return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
self.setlist(key, [value])

Expand Down Expand Up @@ -377,7 +383,7 @@ def update(
self._dict.update(value)


class QueryParams(ImmutableMultiDict):
class QueryParams(ImmutableMultiDict[str, str]):
Copy link
Copy Markdown
Contributor Author

@adriangb adriangb May 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make MyPy happy, previously it was picking up str from ImmutableMultiDict but now that ImmutableMultiDict is generic it was complaining about an unknown type.

Copy link
Copy Markdown
Contributor

@florimondmanca florimondmanca May 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also means we now get QueryParams(...).get(...) to be a str, rather than Any. Which is neat and correct, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! There's a lot of downstream work from this (like actually fixing Form), but I wanted to keep this PR as narrowly scoped as possible. Hopefully we can approve and merge those downstream PRs quicker once the basic structure for stronger typing is added via this PR.

"""
An immutable multidict.
"""
Expand Down
5 changes: 4 additions & 1 deletion tests/test_datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def test_url_blank_params():
assert "abc" in q
assert "def" in q
assert "b" in q
assert len(q.get("abc")) == 0
val = q.get("abc")
assert val is not None
assert len(val) == 0
Comment on lines -223 to +225
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just to make mypy happy

assert len(q["a"]) == 3
assert list(q.keys()) == ["a", "abc", "def", "b"]

Expand Down Expand Up @@ -342,6 +344,7 @@ def test_multidict():
q = MultiDict([("a", "123"), ("a", "456")])
q["a"] = "789"
assert q["a"] == "789"
assert q.get("a") == "789"
assert q.getlist("a") == ["789"]

q = MultiDict([("a", "123"), ("a", "456")])
Expand Down