diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 36ba9f93f..cab1f9b23 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -162,10 +162,11 @@ async def send(self, message: Message, send: Send, request_headers: Headers) -> headers.update(self.simple_headers) origin = request_headers["Origin"] has_cookie = "cookie" in request_headers + has_authorization = "authorization" in request_headers - # If request includes any cookie headers, then we must respond + # If request includes any cookie headers or authorization header, then we must respond # with the specific origin instead of '*'. - if self.allow_all_origins and has_cookie: + if self.allow_all_origins and (has_cookie or has_authorization): self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index cbee7d6e7..db7ccb864 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -52,7 +52,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" - # Test standard credentialed response + # Test standard credentialed response with cookies headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} response = client.get("/", headers=headers) assert response.status_code == 200 @@ -61,6 +61,15 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" + # Test standard credentialed response with Authorization header + headers = {"Origin": "https://example.org", "Authorization": "Bearer token"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + # Test non-CORS response response = client.get("/") assert response.status_code == 200 @@ -416,7 +425,7 @@ def homepage(request: Request) -> PlainTextResponse: ) client = test_client_factory(app) - # Test credentialed request + # Test credentialed request with cookie headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} response = client.get("/", headers=headers) assert response.status_code == 200 @@ -424,6 +433,14 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-allow-origin"] == "https://example.org" assert "access-control-allow-credentials" not in response.headers + # Test credentialed request with Authorization header + headers = {"Origin": "https://example.org", "Authorization": "Bearer token"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers + def test_cors_vary_header_defaults_to_origin( test_client_factory: TestClientFactory, @@ -478,6 +495,11 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" + # Test with Authorization header + response = client.get("/", headers={"Authorization": "Bearer token", "Origin": "https://someplace.org"}) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding, Origin" + def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( test_client_factory: TestClientFactory,