Skip to content

Commit

Permalink
Add support for binding Unix sockets in Linux's abstract namespace. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Dadeos-Menlo authored Jul 9, 2024
1 parent a30a606 commit 58a5482
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 31 deletions.
26 changes: 16 additions & 10 deletions tornado/netutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,23 @@ def bind_unix_socket(
# Hurd doesn't support SO_REUSEADDR
raise
sock.setblocking(False)
try:
st = os.stat(file)
except FileNotFoundError:
pass
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
# File names comprising of an initial null-byte denote an abstract
# namespace, on Linux, and therefore are not subject to file system
# orientated processing.
if not file.startswith("\0"):
try:
st = os.stat(file)
except FileNotFoundError:
pass
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
else:
sock.bind(file)
sock.listen(backlog)
return sock

Expand Down
67 changes: 46 additions & 21 deletions tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import textwrap
import unittest
import urllib.parse
import uuid
from io import BytesIO

import typing
Expand Down Expand Up @@ -813,10 +814,6 @@ def test_manual_protocol(self):
self.assertEqual(self.fetch_json("/")["protocol"], "https")


@unittest.skipIf(
not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
"unix sockets not supported on this platform",
)
class UnixSocketTest(AsyncTestCase):
"""HTTPServers can listen on Unix sockets too.
Expand All @@ -828,44 +825,72 @@ class UnixSocketTest(AsyncTestCase):
an HTTP client, so we have to test this by hand.
"""

address = ""

def setUp(self):
if type(self) is UnixSocketTest:
raise unittest.SkipTest("abstract base class")
super().setUp()
self.tmpdir = tempfile.mkdtemp()
self.sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(self.sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
self.stream = IOStream(socket.socket(socket.AF_UNIX))
self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))
self.server.add_socket(netutil.bind_unix_socket(self.address))

def tearDown(self):
self.stream.close()
self.io_loop.run_sync(self.server.close_all_connections)
self.server.stop()
shutil.rmtree(self.tmpdir)
super().tearDown()

@gen_test
def test_unix_socket(self):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield self.stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield self.stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield self.stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.address)
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")

@gen_test
def test_unix_socket_bad_request(self):
# Unix sockets don't have remote addresses so they just return an
# empty string.
with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO):
self.stream.write(b"garbage\r\n\r\n")
response = yield self.stream.read_until_close()
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.address)
stream.write(b"garbage\r\n\r\n")
response = yield stream.read_until_close()
self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")


@unittest.skipIf(
not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
"unix sockets not supported on this platform",
)
class UnixSocketTestAbstract(UnixSocketTest):

def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.address = os.path.join(self.tmpdir, "test.sock")
super().setUp()

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmpdir)


@unittest.skipIf(
not (hasattr(socket, "AF_UNIX") and sys.platform.startswith("linux")),
"abstract namespace unix sockets not supported on this platform",
)
class UnixSocketTestFile(UnixSocketTest):

def setUp(self):
self.address = "\0" + uuid.uuid4().hex
super().setUp()


class KeepAliveTest(AsyncHTTPTestCase):
"""Tests various scenarios for HTTP 1.1 keep-alive support.
Expand Down

0 comments on commit 58a5482

Please sign in to comment.