Skip to content

Commit

Permalink
Merge pull request #76 from michalpokusa/routes-refactor
Browse files Browse the repository at this point in the history
`Route` refactor and fixes, CPython example
  • Loading branch information
FoamyGuy authored Jan 29, 2024
2 parents 8e0b86a + e4c05ad commit ec9b06f
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 130 deletions.
151 changes: 65 additions & 86 deletions adafruit_httpserver/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

try:
from typing import Callable, List, Iterable, Union, Tuple, Dict, TYPE_CHECKING
from typing import Callable, Iterable, Union, Tuple, Literal, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from .response import Response
Expand All @@ -23,6 +23,24 @@
class Route:
"""Route definition for different paths, see `adafruit_httpserver.server.Server.route`."""

@staticmethod
def _prepare_path_pattern(path: str, append_slash: bool) -> str:
# Escape all dots
path = re.sub(r"\.", r"\\.", path)

# Replace url parameters with regex groups
path = re.sub(r"<\w+>", r"([^/]+)", path)

# Replace wildcards with corresponding regex
path = path.replace(r"\.\.\.\.", r".+").replace(r"\.\.\.", r"[^/]+")

# Add optional slash at the end if append_slash is True
if append_slash:
path += r"/?"

# Add start and end of string anchors
return f"^{path}$"

def __init__(
self,
path: str = "",
Expand All @@ -33,80 +51,89 @@ def __init__(
) -> None:
self._validate_path(path, append_slash)

self.parameters_names = [
name[1:-1] for name in re.compile(r"/[^<>]*/?").split(path) if name != ""
]
self.path = re.sub(r"<\w+>", r"([^/]+)", path).replace("....", r".+").replace(
"...", r"[^/]+"
) + ("/?" if append_slash else "")
self.path = path
self.methods = (
set(methods) if isinstance(methods, (set, list, tuple)) else set([methods])
)

self.handler = handler
self.parameters_names = [
name[1:-1] for name in re.compile(r"/[^<>]*/?").split(path) if name != ""
]
self.path_pattern = re.compile(self._prepare_path_pattern(path, append_slash))

@staticmethod
def _validate_path(path: str, append_slash: bool) -> None:
if not path.startswith("/"):
raise ValueError("Path must start with a slash.")

if path.endswith("/") and append_slash:
raise ValueError("Cannot use append_slash=True when path ends with /")

if "//" in path:
raise ValueError("Path cannot contain double slashes.")

if "<>" in path:
raise ValueError("All URL parameters must be named.")

if path.endswith("/") and append_slash:
raise ValueError("Cannot use append_slash=True when path ends with /")
if re.search(r"[^/]<[^/]+>|<[^/]+>[^/]", path):
raise ValueError("All URL parameters must be between slashes.")

if re.search(r"[^/.]\.\.\.\.?|\.?\.\.\.[^/.]", path):
raise ValueError("... and .... must be between slashes")

def match(self, other: "Route") -> Tuple[bool, Dict[str, str]]:
if "....." in path:
raise ValueError("Path cannot contain more than 4 dots in a row.")

def matches(
self, method: str, path: str
) -> Union[Tuple[Literal[False], None], Tuple[Literal[True], Dict[str, str]]]:
"""
Checks if the route matches the other route.
Checks if the route matches given ``method`` and ``path``.
If the route contains parameters, it will check if the ``other`` route contains values for
If the route contains parameters, it will check if the ``path`` contains values for
them.
Returns tuple of a boolean and a list of strings. The boolean indicates if the routes match,
and the list contains the values of the url parameters from the ``other`` route.
Returns tuple of a boolean that indicates if the routes matches and a dict containing
values for url parameters.
If the route does not match ``path`` or ``method`` if will return ``None`` instead of dict.
Examples::
route = Route("/example", GET, True)
route = Route("/example", GET, append_slash=True)
other1a = Route("/example", GET)
other1b = Route("/example/", GET)
route.matches(other1a) # True, {}
route.matches(other1b) # True, {}
route.matches(GET, "/example") # True, {}
route.matches(GET, "/example/") # True, {}
other2 = Route("/other-example", GET)
route.matches(other2) # False, {}
route.matches(GET, "/other-example") # False, None
route.matches(POST, "/example/") # False, None
...
route = Route("/example/<parameter>", GET)
other1 = Route("/example/123", GET)
route.matches(other1) # True, {"parameter": "123"}
route.matches(GET, "/example/123") # True, {"parameter": "123"}
other2 = Route("/other-example", GET)
route.matches(other2) # False, {}
route.matches(GET, "/other-example") # False, None
...
route1 = Route("/example/.../something", GET)
other1 = Route("/example/123/something", GET)
route1.matches(other1) # True, {}
route = Route("/example/.../something", GET)
route.matches(GET, "/example/123/something") # True, {}
route2 = Route("/example/..../something", GET)
other2 = Route("/example/123/456/something", GET)
route2.matches(other2) # True, {}
route = Route("/example/..../something", GET)
route.matches(GET, "/example/123/456/something") # True, {}
"""

if not other.methods.issubset(self.methods):
return False, {}
if method not in self.methods:
return False, None

path_match = self.path_pattern.match(path)
if path_match is None:
return False, None

regex_match = re.match(f"^{self.path}$", other.path)
if regex_match is None:
return False, {}
url_parameters_values = path_match.groups()

return True, dict(zip(self.parameters_names, regex_match.groups()))
return True, dict(zip(self.parameters_names, url_parameters_values))

def __repr__(self) -> str:
path = repr(self.path)
Expand Down Expand Up @@ -168,51 +195,3 @@ def route_decorator(func: Callable) -> Route:
return Route(path, methods, func, append_slash=append_slash)

return route_decorator


class _Routes:
"""A collection of routes and their corresponding handlers."""

def __init__(self) -> None:
self._routes: List[Route] = []

def add(self, route: Route):
"""Adds a route and its handler to the collection."""
self._routes.append(route)

def find_handler(self, route: Route) -> Union[Callable["...", "Response"], None]:
"""
Finds a handler for a given route.
If route used URL parameters, the handler will be wrapped to pass the parameters to the
handler.
Example::
@server.route("/example/<my_parameter>", GET)
def route_func(request, my_parameter):
...
request.path == "/example/123" # True
my_parameter == "123" # True
"""
found_route, _route = False, None

for _route in self._routes:
matches, keyword_parameters = _route.match(route)

if matches:
found_route = True
break

if not found_route:
return None

handler = _route.handler

def wrapped_handler(request):
return handler(request, **keyword_parameters)

return wrapped_handler

def __repr__(self) -> str:
return f"_Routes({repr(self._routes)})"
102 changes: 64 additions & 38 deletions adafruit_httpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .methods import GET, HEAD
from .request import Request
from .response import Response, FileResponse
from .route import _Routes, Route
from .route import Route
from .status import BAD_REQUEST_400, UNAUTHORIZED_401, FORBIDDEN_403, NOT_FOUND_404


Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
self._auths = []
self._buffer = bytearray(1024)
self._timeout = 1
self._routes = _Routes()
self._routes: "List[Route]" = []
self._socket_source = socket_source
self._sock = None
self.headers = Headers()
Expand Down Expand Up @@ -132,7 +132,7 @@ def route_func(request):
"""

def route_decorator(func: Callable) -> Callable:
self._routes.add(Route(path, methods, func, append_slash=append_slash))
self._routes.append(Route(path, methods, func, append_slash=append_slash))
return func

return route_decorator
Expand All @@ -157,8 +157,7 @@ def add_routes(self, routes: List[Route]) -> None:
external_route2,
]}
"""
for route in routes:
self._routes.add(route)
self._routes.extend(routes)

def _verify_can_start(self, host: str, port: int) -> None:
"""Check if the server can be successfully started. Raises RuntimeError if not."""
Expand All @@ -172,7 +171,7 @@ def _verify_can_start(self, host: str, port: int) -> None:
raise RuntimeError(f"Cannot start server on {host}:{port}") from error

def serve_forever(
self, host: str, port: int = 80, *, poll_interval: float = None
self, host: str, port: int = 80, *, poll_interval: float = 0.1
) -> None:
"""
Wait for HTTP requests at the given host and port. Does not return.
Expand All @@ -187,16 +186,14 @@ def serve_forever(

while not self.stopped:
try:
self.poll()
if self.poll() == NO_REQUEST and poll_interval is not None:
sleep(poll_interval)
except KeyboardInterrupt: # Exit on Ctrl-C e.g. during development
self.stop()
return
except Exception: # pylint: disable=broad-except
pass # Ignore exceptions in handler function

if poll_interval is not None:
sleep(poll_interval)

def start(self, host: str, port: int = 80) -> None:
"""
Start the HTTP server at the given host and port. Requires calling
Expand Down Expand Up @@ -234,32 +231,6 @@ def stop(self) -> None:
if self.debug:
_debug_stopped_server(self)

def _receive_request(
self,
sock: Union["SocketPool.Socket", "socket.socket"],
client_address: Tuple[str, int],
) -> Request:
"""Receive bytes from socket until the whole request is received."""

# Receiving data until empty line
header_bytes = self._receive_header_bytes(sock)

# Return if no data received
if not header_bytes:
return None

request = Request(self, sock, client_address, header_bytes)

content_length = int(request.headers.get_directive("Content-Length", 0))
received_body_bytes = request.body

# Receiving remaining body bytes
request.body = self._receive_body_bytes(
sock, received_body_bytes, content_length
)

return request

def _receive_header_bytes(
self, sock: Union["SocketPool.Socket", "socket.socket"]
) -> bytes:
Expand Down Expand Up @@ -296,6 +267,61 @@ def _receive_body_bytes(
raise ex
return received_body_bytes[:content_length]

def _receive_request(
self,
sock: Union["SocketPool.Socket", "socket.socket"],
client_address: Tuple[str, int],
) -> Request:
"""Receive bytes from socket until the whole request is received."""

# Receiving data until empty line
header_bytes = self._receive_header_bytes(sock)

# Return if no data received
if not header_bytes:
return None

request = Request(self, sock, client_address, header_bytes)

content_length = int(request.headers.get_directive("Content-Length", 0))
received_body_bytes = request.body

# Receiving remaining body bytes
request.body = self._receive_body_bytes(
sock, received_body_bytes, content_length
)

return request

def _find_handler( # pylint: disable=cell-var-from-loop
self, method: str, path: str
) -> Union[Callable[..., "Response"], None]:
"""
Finds a handler for a given route.
If route used URL parameters, the handler will be wrapped to pass the parameters to the
handler.
Example::
@server.route("/example/<my_parameter>", GET)
def route_func(request, my_parameter):
...
request.path == "/example/123" # True
my_parameter == "123" # True
"""
for route in self._routes:
route_matches, url_parameters = route.matches(method, path)

if route_matches:

def wrapped_handler(request):
return route.handler(request, **url_parameters)

return wrapped_handler

return None

def _handle_request(
self, request: Request, handler: Union[Callable, None]
) -> Union[Response, None]:
Expand Down Expand Up @@ -371,8 +397,8 @@ def poll(self) -> str:
conn.close()
return CONNECTION_TIMED_OUT

# Find a handler for the route
handler = self._routes.find_handler(Route(request.path, request.method))
# Find a route that matches the request's method and path and get its handler
handler = self._find_handler(request.method, request.path)

# Handle the request
response = self._handle_request(request, handler)
Expand Down
Loading

0 comments on commit ec9b06f

Please sign in to comment.