diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 5e724d19d..7aae90ec7 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -13,6 +13,13 @@ class Address(typing.NamedTuple): port: int +_KeyType = typing.TypeVar("_KeyType") +# Mapping keys are invariant but their values are covariant since +# you can only read them +# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` +_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) + + class URL: def __init__( self, @@ -238,32 +245,36 @@ def __str__(self) -> str: return ", ".join(repr(item) for item in self) -class ImmutableMultiDict(typing.Mapping): +class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): + _dict: typing.Dict[_KeyType, _CovariantValueType] + def __init__( self, *args: typing.Union[ - "ImmutableMultiDict", - typing.Mapping, - typing.List[typing.Tuple[typing.Any, typing.Any]], + "ImmutableMultiDict[_KeyType, _CovariantValueType]", + typing.Mapping[_KeyType, _CovariantValueType], + typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]], ], **kwargs: typing.Any, ) -> None: assert len(args) < 2, "Too many arguments." - value = args[0] if args else [] + value: typing.Any = args[0] if args else [] if kwargs: value = ( ImmutableMultiDict(value).multi_items() - + ImmutableMultiDict(kwargs).multi_items() + + ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator] ) if not value: _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = [] elif hasattr(value, "multi_items"): - value = typing.cast(ImmutableMultiDict, value) + value = typing.cast( + ImmutableMultiDict[_KeyType, _CovariantValueType], value + ) _items = list(value.multi_items()) elif hasattr(value, "items"): - value = typing.cast(typing.Mapping, value) + value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) _items = list(value.items()) else: value = typing.cast( @@ -274,33 +285,28 @@ def __init__( self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> typing.List[typing.Any]: + def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView: + def keys(self) -> typing.KeysView[_KeyType]: return self._dict.keys() - def values(self) -> typing.ValuesView: + def values(self) -> typing.ValuesView[_CovariantValueType]: return self._dict.values() - def items(self) -> typing.ItemsView: + def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: return self._dict.items() - def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]: return list(self._list) - def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: - if key in self._dict: - return self._dict[key] - return default - - def __getitem__(self, key: typing.Any) -> str: + def __getitem__(self, key: _KeyType) -> _CovariantValueType: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: return key in self._dict - def __iter__(self) -> typing.Iterator[typing.Any]: + def __iter__(self) -> typing.Iterator[_KeyType]: return iter(self.keys()) def __len__(self) -> int: @@ -317,7 +323,7 @@ def __repr__(self) -> str: return f"{class_name}({items!r})" -class MultiDict(ImmutableMultiDict): +class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): def __setitem__(self, key: typing.Any, value: typing.Any) -> None: self.setlist(key, [value]) @@ -377,7 +383,7 @@ def update( self._dict.update(value) -class QueryParams(ImmutableMultiDict): +class QueryParams(ImmutableMultiDict[str, str]): """ An immutable multidict. """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 22e377c99..3ba8bbebc 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -220,7 +220,9 @@ def test_url_blank_params(): assert "abc" in q assert "def" in q assert "b" in q - assert len(q.get("abc")) == 0 + val = q.get("abc") + assert val is not None + assert len(val) == 0 assert len(q["a"]) == 3 assert list(q.keys()) == ["a", "abc", "def", "b"] @@ -342,6 +344,7 @@ def test_multidict(): q = MultiDict([("a", "123"), ("a", "456")]) q["a"] = "789" assert q["a"] == "789" + assert q.get("a") == "789" assert q.getlist("a") == ["789"] q = MultiDict([("a", "123"), ("a", "456")])