Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import typing
import warnings

from typing_extensions import ParamSpec

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand All @@ -15,6 +17,7 @@
from starlette.websockets import WebSocket

AppType = typing.TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")


class Starlette:
Expand Down Expand Up @@ -98,8 +101,8 @@ def build_middleware_stack(self) -> ASGIApp:
)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
for cls, args, kwargs in reversed(middleware):
app = cls(app=app, *args, **kwargs)
return app

@property
Expand All @@ -124,10 +127,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self,
middleware_class: typing.Type[_MiddlewareClass[P]],
*args: P.args,
Comment thread
pawelrubin marked this conversation as resolved.
**kwargs: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))

def add_exception_handler(
self,
Expand Down
35 changes: 28 additions & 7 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
import typing
from typing import Any, Iterator, Protocol, Type

from typing_extensions import ParamSpec

from starlette.types import ASGIApp, Receive, Scope, Send

P = ParamSpec("P")


class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
... # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover


class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(
self,
cls: Type[_MiddlewareClass[P]],
*args: P.args,
Comment thread
pawelrubin marked this conversation as resolved.
**kwargs: P.kwargs,
) -> None:
self.cls = cls
self.options = options
self.args = args
self.kwargs = kwargs

def __iter__(self) -> typing.Iterator[typing.Any]:
as_tuple = (self.cls, self.options)
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)

def __repr__(self) -> str:
class_name = self.__class__.__name__
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
args_repr = ", ".join([self.cls.__name__] + option_strings)
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
return f"{class_name}({args_repr})"
16 changes: 8 additions & 8 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def __init__(
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)

if methods is None:
self.methods = None
Expand Down Expand Up @@ -335,8 +335,8 @@ def __init__(
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

Expand Down Expand Up @@ -404,8 +404,8 @@ def __init__(
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
Expand Down Expand Up @@ -672,8 +672,8 @@ def __init__(

self.middleware_stack = self.app
if middleware:
for cls, options in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, **options)
for cls, args, kwargs in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down
6 changes: 3 additions & 3 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Awaitable, Callable, List, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -196,7 +196,7 @@ async def dispatch(self, request, call_next):
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand Down
22 changes: 17 additions & 5 deletions tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from starlette.middleware import Middleware
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware:
pass
class CustomMiddleware: # pragma: no cover
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
self.app = app
self.foo = foo
self.bar = bar

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"

def test_middleware_repr() -> None:
middleware = Middleware(CustomMiddleware, "foo", bar=123)
assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"


def test_middleware_iter() -> None:
cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
12 changes: 6 additions & 6 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable
from typing import AsyncIterator, Callable

import anyio
import httpx
Expand All @@ -15,7 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -499,8 +499,8 @@ class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

class SimpleInitializableMiddleware:
counter = 0
Expand All @@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def get_app() -> ASGIApp:
app = Starlette()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocketDisconnect
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory):
assert response.json() == {"authenticated": True, "user": "tomchristie"}


def on_auth_error(request: Request, exc: Exception):
def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
return JSONResponse({"error": str(exc)}, status_code=401)


Expand Down