Skip to content
29 changes: 17 additions & 12 deletions starlette/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,32 @@ def __init__(
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)

allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials

simple_headers = {}
if "*" in allow_origins:
if allow_all_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)

preflight_headers = {}
if "*" in allow_origins:
preflight_headers["Access-Control-Allow-Origin"] = "*"
else:
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and "*" not in allow_headers:
if allow_headers and not allow_all_headers:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't allow_headers redundant here since it's never an empty list? This could be reduced to this, right?:

if not allow_all_headers:
    ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it looks unused.

Also a bit concerned about headers with mixed case in the Access-Control-Allow-Headers header, but if it is an issue it's a separate one.

preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
Expand All @@ -59,8 +64,10 @@ def __init__(
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = "*" in allow_origins
self.allow_all_headers = "*" in allow_headers
self.allow_credentials = allow_credentials
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
Expand Down Expand Up @@ -105,11 +112,9 @@ def preflight_response(self, request_headers: Headers) -> Response:
failures = []

if self.is_allowed_origin(origin=requested_origin):
if not self.allow_all_origins:
# If self.allow_all_origins is True, then the
# "Access-Control-Allow-Origin" header is already set to "*".
# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")
Expand Down
111 changes: 110 additions & 1 deletion tests/middleware/test_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,62 @@ def homepage(request):

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_all_except_credentials():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
)

@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
Expand All @@ -33,6 +89,8 @@ def homepage(request):
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert "access-control-allow-credentials" not in response.headers
assert "vary" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
Expand All @@ -41,6 +99,7 @@ def homepage(request):
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -77,13 +136,15 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert "access-control-allow-credentials" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -116,6 +177,38 @@ def homepage(request):
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_credentials=True,
)

@app.route("/")
def homepage(request):
return # pragma: no cover

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
response = client.options(
"/",
headers=headers,
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"


def test_cors_preflight_allow_all_methods():
Expand Down Expand Up @@ -175,6 +268,7 @@ def test_cors_allow_origin_regex():
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex="https://.*",
allow_credentials=True,
)

@app.route("/")
Expand All @@ -189,8 +283,17 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test diallowed standard response
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed standard response
# Note that enforcement is a browser concern. The disallowed-ness is reflected
# in the lack of an "access-control-allow-origin" header in the response.
headers = {"Origin": "http://example.org"}
Expand All @@ -212,6 +315,7 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed pre-flight response
headers = {
Expand Down Expand Up @@ -249,6 +353,7 @@ def homepage(request):
response.headers["access-control-allow-origin"]
== "https://subdomain.example.org"
)
assert "access-control-allow-credentials" not in response.headers

# Test diallowed standard response
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
Expand All @@ -275,6 +380,7 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers


def test_cors_vary_header_defaults_to_origin():
Expand Down Expand Up @@ -365,11 +471,14 @@ def homepage(request):
client = TestClient(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers

response = client.get(
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
)
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers

response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers