Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Correct type hints for parse_string(s)_from_args. #10137

Merged
merged 9 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10137.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `parse_strings_from_args` for parsing an array from query parameters.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ files =
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
synapse/http/servlet.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging,
Expand Down
179 changes: 111 additions & 68 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
""" This module contains base REST classes for constructing REST servlets. """

import logging
from typing import Iterable, List, Optional, Union, overload
from typing import Dict, Iterable, List, Optional, overload

from typing_extensions import Literal

from twisted.web.server import Request

from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder

Expand Down Expand Up @@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return default


@overload
def parse_bytes_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Literal[None] = None,
required: Literal[True] = True,
) -> bytes:
...


@overload
def parse_bytes_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[bytes] = None,
required: bool = False,
) -> Optional[bytes]:
...


def parse_bytes_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[bytes] = None,
required: bool = False,
) -> Optional[bytes]:
"""
Parse a string parameter as bytes from the request query string.

Args:
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
Bytes or the default value.

Raises:
SynapseError if the parameter is absent and required.
"""
name_bytes = name.encode("ascii")

if name_bytes in args:
return args[name_bytes][0]
elif required:
message = "Missing string query parameter %s" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)

return default


def parse_string(
request,
name: Union[bytes, str],
request: Request,
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
encoding: str = "ascii",
):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a return value here cascades to a lot of changes. We should do it, but I'd prefer to not do it in this PR.

"""
Parse a string parameter from the request query string.
Expand All @@ -125,66 +180,65 @@ def parse_string(
Args:
request: the twisted HTTP request.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding : The encoding to decode the string content with.
encoding: The encoding to decode the string content with.

Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
A string value or the default.

Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
return parse_string_from_args(
request.args, name, default, required, allowed_values, encoding
args, name, default, required, allowed_values, encoding
)


def _parse_string_value(
value: Union[str, bytes],
value: bytes,
allowed_values: Optional[Iterable[str]],
name: str,
encoding: Optional[str],
) -> Union[str, bytes]:
if encoding:
try:
value = value.decode(encoding)
except ValueError:
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
encoding: str,
) -> str:
try:
value_str = value.decode(encoding)
except ValueError:
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))

if allowed_values is not None and value not in allowed_values:
if allowed_values is not None and value_str not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
return value
return value_str


@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
required: bool = False,
required: Literal[True] = True,
allowed_values: Optional[Iterable[str]] = None,
encoding: Literal[None] = None,
) -> Optional[List[bytes]]:
encoding: str = "ascii",
) -> List[str]:
...


@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
Expand All @@ -194,83 +248,71 @@ def parse_strings_from_args(


def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[List[Union[bytes, str]]]:
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
Parse a string parameter from the request query string list.

If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
The content of the query param will be decoded to Unicode using the encoding.

Args:
args: the twisted HTTP request.args list.
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required : whether to raise a 400 SynapseError if the
default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[bytes|unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None.
encoding: The encoding to decode the string content with.

Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
A string value or the default.

Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
name_bytes = name.encode("ascii")

if not isinstance(name, bytes):
name = name.encode("ascii")

if name in args:
values = args[name]
if name_bytes in args:
values = args[name_bytes]

return [
_parse_string_value(value, allowed_values, name=name, encoding=encoding)
for value in values
]
else:
if required:
message = "Missing string query parameter %r" % (name)
message = "Missing string query parameter %r" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:

if encoding and isinstance(default, bytes):
return default.decode(encoding)

return default
return default


def parse_string_from_args(
args: List[str],
name: Union[bytes, str],
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[Union[bytes, str]]:
encoding: str = "ascii",
) -> Optional[str]:
"""
Parse the string parameter from the request query string list
and return the first result.

If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
The content of the query param will be decoded to Unicode using the encoding.

Args:
args: the twisted HTTP request.args list.
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values: List of allowed values for the
Expand All @@ -279,8 +321,7 @@ def parse_string_from_args(
encoding: The encoding to decode the string content with.

Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
A string value or the default.

Raises:
SynapseError if the parameter is absent and required, or if the
Expand All @@ -291,12 +332,15 @@ def parse_string_from_args(
strings = parse_strings_from_args(
args,
name,
default=[default],
default=[default] if default is not None else None,
required=required,
allowed_values=allowed_values,
encoding=encoding,
)

if strings is None:
return None

return strings[0]


Expand Down Expand Up @@ -388,9 +432,8 @@ class attribute containing a pre-compiled regular expression. The automatic

def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERNS"):
patterns = self.PATTERNS

patterns = getattr(self, "PATTERNS", None)
if patterns:
for method in ("GET", "PUT", "POST", "DELETE"):
if hasattr(self, "on_%s" % (method,)):
servlet_classname = self.__class__.__name__
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ async def on_GET(
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, b"filter", encoding="utf-8")
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
Expand All @@ -25,6 +25,7 @@
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
)
Expand Down Expand Up @@ -437,9 +438,8 @@ async def on_GET(
finish_request(request)
return

client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
sso_url = await self._sso_handler.handle_redirect_request(
request,
client_redirect_url,
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v1/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ async def on_GET(self, request, room_id):
self.store, request, default_limit=10
)
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8")
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
Expand Down Expand Up @@ -652,7 +652,7 @@ async def on_GET(self, request, room_id, event_id):
limit = parse_integer(request, "limit", default=10)

# picking the API shape for symmetry with /messages
filter_str = parse_string(request, b"filter", encoding="utf-8")
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
Expand Down
Loading