Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import ParamSpec

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -96,7 +96,7 @@ def build_middleware_stack(self) -> ASGIApp:

app = self.router
for cls, args, kwargs in reversed(middleware):
app = cls(app=app, *args, **kwargs)
app = cls(app, *args, **kwargs)
return app

@property
Expand All @@ -123,7 +123,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:

def add_middleware(
self,
middleware_class: type[_MiddlewareClass[P]],
middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
13 changes: 6 additions & 7 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp

P = ParamSpec("P")


class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover
class _MiddlewareFactory(Protocol[P]):
def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover


class Middleware:
def __init__(
self,
cls: type[_MiddlewareClass[P]],
cls: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand All @@ -38,5 +36,6 @@ def __repr__(self) -> str:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
name = getattr(self.cls, "__name__", "")
args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"
6 changes: 3 additions & 3 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __init__(

if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)

if methods is None:
self.methods = None
Expand Down Expand Up @@ -328,7 +328,7 @@ def __init__(

if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

Expand Down Expand Up @@ -388,7 +388,7 @@ def __init__(
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")

Expand Down
4 changes: 2 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -232,7 +232,7 @@ async def dispatch(
)
def test_contextvars(
test_client_factory: TestClientFactory,
middleware_cls: type[_MiddlewareClass[Any]],
middleware_cls: _MiddlewareFactory[Any],
) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
Expand Down
44 changes: 44 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from contextlib import asynccontextmanager
from pathlib import Path
Expand Down Expand Up @@ -533,6 +535,48 @@ def get_app() -> ASGIApp:
assert SimpleInitializableMiddleware.counter == 2


def test_middleware_args(test_client_factory: TestClientFactory) -> None:
calls: list[str] = []

class MiddlewareWithArgs:
def __init__(self, app: ASGIApp, arg: str) -> None:
self.app = app
self.arg = arg

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
calls.append(self.arg)
await self.app(scope, receive, send)

app = Starlette()
app.add_middleware(MiddlewareWithArgs, "foo")
app.add_middleware(MiddlewareWithArgs, "bar")

with test_client_factory(app):
pass

assert calls == ["bar", "foo"]


def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
calls: list[str] = []

def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
async def _app(scope: Scope, receive: Receive, send: Send) -> None:
calls.append(arg)
await app(scope, receive, send)

return _app

app = Starlette()
app.add_middleware(_middleware_factory, arg="foo")
app.add_middleware(_middleware_factory, arg="bar")

with test_client_factory(app):
pass

assert calls == ["bar", "foo"]


def test_lifespan_app_subclass() -> None:
# This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan
Expand Down