Skip to content

Commit

Permalink
Merge pull request #429 from zanieb/zb/fix-head-close
Browse files Browse the repository at this point in the history
Always attempt to set the `Connection: close` response header
  • Loading branch information
digitalresistor authored Feb 4, 2024
2 parents e9796b1 + edbe6b8 commit f737755
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions src/waitress/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ def has_body(self):
or self.status.startswith("304")
)

def set_close_on_finish(self) -> None:
# if headers have not been written yet, tell the remote
# client we are closing the connection
if not self.wrote_header:
connection_close_header = None
for headername, headerval in self.response_headers:
if headername.capitalize() == "Connection":
connection_close_header = headerval.lower()
if connection_close_header is None:
self.response_headers.append(("Connection", "close"))
self.close_on_finish = True

def build_response_header(self):
version = self.version
# Figure out whether the connection should be closed.
Expand All @@ -188,7 +200,6 @@ def build_response_header(self):
content_length_header = None
date_header = None
server_header = None
connection_close_header = None

for headername, headerval in self.response_headers:
headername = "-".join([x.capitalize() for x in headername.split("-")])
Expand All @@ -205,47 +216,43 @@ def build_response_header(self):
if headername == "Server":
server_header = headerval

if headername == "Connection":
connection_close_header = headerval.lower()
# replace with properly capitalized version
response_headers.append((headername, headerval))

# Overwrite the response headers we have with normalized ones
self.response_headers = response_headers

if (
content_length_header is None
and self.content_length is not None
and self.has_body
):
content_length_header = str(self.content_length)
response_headers.append(("Content-Length", content_length_header))

def close_on_finish():
if connection_close_header is None:
response_headers.append(("Connection", "close"))
self.close_on_finish = True
self.response_headers.append(("Content-Length", content_length_header))

if version == "1.0":
if connection == "keep-alive":
if not content_length_header:
close_on_finish()
self.set_close_on_finish()
else:
response_headers.append(("Connection", "Keep-Alive"))
self.response_headers.append(("Connection", "Keep-Alive"))
else:
close_on_finish()
self.set_close_on_finish()

elif version == "1.1":
if connection == "close":
close_on_finish()
self.set_close_on_finish()

if not content_length_header:
# RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
# for any response with a status code of 1xx, 204 or 304.

if self.has_body:
response_headers.append(("Transfer-Encoding", "chunked"))
self.response_headers.append(("Transfer-Encoding", "chunked"))
self.chunked_response = True

if not self.close_on_finish:
close_on_finish()
self.set_close_on_finish()

# under HTTP 1.1 keep-alive is default, no need to set the header
else:
Expand All @@ -257,14 +264,12 @@ def close_on_finish():

if not server_header:
if ident:
response_headers.append(("Server", ident))
self.response_headers.append(("Server", ident))
else:
response_headers.append(("Via", ident or "waitress"))
self.response_headers.append(("Via", ident or "waitress"))

if not date_header:
response_headers.append(("Date", build_http_date(self.start_time)))

self.response_headers = response_headers
self.response_headers.append(("Date", build_http_date(self.start_time)))

first_line = f"HTTP/{self.version} {self.status}"
# NB: sorting headers needs to preserve same-named-header order
Expand Down Expand Up @@ -350,11 +355,7 @@ def execute(self):
status, headers, body = e.to_response(ident)
self.status = status
self.response_headers.extend(headers)
# We need to explicitly tell the remote client we are closing the
# connection, because self.close_on_finish is set, and we are going to
# slam the door in the clients face.
self.response_headers.append(("Connection", "close"))
self.close_on_finish = True
self.set_close_on_finish()
self.content_length = len(body)
self.write(body)

Expand Down Expand Up @@ -388,7 +389,7 @@ def start_response(status, headers, exc_info=None):

self.complete = True

if not status.__class__ is str:
if status.__class__ is not str:
raise AssertionError("status %s is not a string" % status)
if "\n" in status or "\r" in status:
raise ValueError(
Expand All @@ -399,11 +400,11 @@ def start_response(status, headers, exc_info=None):

# Prepare the headers for output
for k, v in headers:
if not k.__class__ is str:
if k.__class__ is not str:
raise AssertionError(
f"Header name {k!r} is not a string in {(k, v)!r}"
)
if not v.__class__ is str:
if v.__class__ is not str:
raise AssertionError(
f"Header value {v!r} is not a string in {(k, v)!r}"
)
Expand Down Expand Up @@ -478,14 +479,14 @@ def start_response(status, headers, exc_info=None):
# close the connection so the client isn't sitting around
# waiting for more data when there are too few bytes
# to service content-length
# unless it's a HEAD request in which case we don't expect
# to return any bytes regardless of the content length
self.close_on_finish = True
self.logger.warning(
"application returned too few bytes (%s) "
"for specified Content-Length (%s) via app_iter"
% (self.content_bytes_written, cl),
)
self.set_close_on_finish()
if self.request.command != "HEAD":
self.logger.warning(
"application returned too few bytes (%s) "
"for specified Content-Length (%s) via app_iter",
self.content_bytes_written,
cl,
)
finally:
if can_close_app_iter and hasattr(app_iter, "close"):
app_iter.close()
Expand Down

0 comments on commit f737755

Please sign in to comment.