Skip to content

Commit

Permalink
Fix add_middleware enum comparison (#1698)
Browse files Browse the repository at this point in the history
Fixes #1697

Because of a wrong comparison against the position `Enum`, middleware
was not actually being added to the stack via `add_middleware`. This PR
fixes this, adds a warning when the middleware position cannot be found,
and adds a test.
  • Loading branch information
RobbeSneyders authored Apr 24, 2023
1 parent 128a8e0 commit 15fe2ed
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
12 changes: 10 additions & 2 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import dataclasses
import enum
import logging
Expand Down Expand Up @@ -180,7 +181,9 @@ def __init__(
self.app = app
self.lifespan = lifespan
self.middlewares = (
middlewares if middlewares is not None else self.default_middlewares
middlewares
if middlewares is not None
else copy.copy(self.default_middlewares)
)
self.middleware_stack: t.Optional[t.Iterable[ASGIApp]] = None
self.apis: t.List[API] = []
Expand Down Expand Up @@ -223,11 +226,16 @@ def add_middleware(
if isinstance(middleware, partial):
middleware = middleware.func

if middleware == position:
if middleware == position.value:
self.middlewares.insert(
m, t.cast(ASGIApp, partial(middleware_class, **options))
)
break
else:
raise ValueError(
f"Could not insert middleware at position {position.name}. "
f"Please make sure you have a {position.value} in your stack."
)

def _build_middleware_stack(self) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
"""Apply all middlewares to the provided app.
Expand Down
5 changes: 0 additions & 5 deletions tests/api/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import json

import flask


def fix_data(data):
return data.replace(b'\\"', b'"')

Expand Down
40 changes: 36 additions & 4 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys
from unittest import mock

import pytest
from connexion.middleware import ConnexionMiddleware
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from starlette.datastructures import MutableHeaders

from conftest import build_app_from_fixture
Expand Down Expand Up @@ -49,3 +47,37 @@ def test_routing_middleware(middleware_app):
assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code


def test_add_middleware(spec, app_class):
"""Test adding middleware via the `add_middleware` method."""
app = build_app_from_fixture("simple", app_class=app_class, spec_file=spec)
app.add_middleware(TestMiddleware)

app_client = app.test_client()
response = app_client.post("/v1.0/greeting/robbe")

assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code


def test_position(spec, app_class):
"""Test adding middleware via the `add_middleware` method."""
middlewares = [
middleware
for middleware in ConnexionMiddleware.default_middlewares
if middleware != SwaggerUIMiddleware
]
app = build_app_from_fixture(
"simple", app_class=app_class, spec_file=spec, middlewares=middlewares
)

with pytest.raises(ValueError) as exc_info:
app.add_middleware(TestMiddleware, position=MiddlewarePosition.BEFORE_SWAGGER)

assert (
exc_info.value.args[0]
== f"Could not insert middleware at position BEFORE_SWAGGER. "
f"Please make sure you have a {SwaggerUIMiddleware} in your stack."
)

0 comments on commit 15fe2ed

Please sign in to comment.