|
3 | 3 | import sys
|
4 | 4 | from typing import TYPE_CHECKING
|
5 | 5 |
|
| 6 | +import pytest |
| 7 | + |
6 | 8 | import trio.socket as tsocket
|
7 |
| -from trio import SocketListener |
8 | 9 | from trio._highlevel_open_unix_listeners import (
|
| 10 | + UnixSocketListener, |
9 | 11 | open_unix_listeners,
|
10 | 12 | )
|
11 |
| -from trio.testing import open_stream_to_socket_listener |
| 13 | +from trio._highlevel_socket import SocketStream |
| 14 | + |
| 15 | +assert not TYPE_CHECKING or sys.platform != "win32" |
12 | 16 |
|
13 |
| -assert ( # Skip type checking when on Windows |
14 |
| - sys.platform != "win32" or not TYPE_CHECKING |
| 17 | +skip_if_not_unix = pytest.mark.skipif( |
| 18 | + not hasattr(tsocket, "AF_UNIX"), |
| 19 | + reason="Needs unix socket support", |
15 | 20 | )
|
16 | 21 |
|
17 | 22 |
|
| 23 | +async def open_stream_to_unix_socket_listener( |
| 24 | + socket_listener: UnixSocketListener, |
| 25 | + sockaddr: str, |
| 26 | +) -> SocketStream: |
| 27 | + """Connect to the given :class:`~trio.UnixSocketListener`. |
| 28 | +
|
| 29 | + This is particularly useful in tests when you want to let a server pick |
| 30 | + its own port, and then connect to it:: |
| 31 | +
|
| 32 | + listeners = await trio.open_tcp_listeners(0) |
| 33 | + client = await trio.testing.open_stream_to_socket_listener(listeners[0]) |
| 34 | +
|
| 35 | + Args: |
| 36 | + socket_listener (~trio.UnixSocketListener): The |
| 37 | + :class:`~trio.UnixSocketListener` to connect to. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + SocketStream: a stream connected to the given listener. |
| 41 | +
|
| 42 | + """ |
| 43 | + family = socket_listener.socket.family |
| 44 | + assert family == tsocket.AF_UNIX |
| 45 | + |
| 46 | + sock = tsocket.socket(family=family) |
| 47 | + await sock.connect(sockaddr) |
| 48 | + return SocketStream(sock) |
| 49 | + |
| 50 | + |
| 51 | +@skip_if_not_unix |
18 | 52 | async def test_open_unix_listeners_basic() -> None:
|
19 | 53 | # Since we are on unix, we can use fun things like /tmp
|
20 |
| - listeners = await open_unix_listeners("/tmp/test_socket.sock", backlog=0) |
| 54 | + path = "/tmp/test_socket.sock" |
| 55 | + listeners = await open_unix_listeners(path, backlog=0) |
21 | 56 | assert isinstance(listeners, list)
|
22 | 57 | for obj in listeners:
|
23 |
| - assert isinstance(obj, SocketListener) |
24 |
| - # Binds to wildcard address by default |
25 |
| - assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6] |
26 |
| - assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"] |
| 58 | + assert obj.socket.family == tsocket.AF_UNIX |
| 59 | + # Does not work because of atomic overwrite |
| 60 | + # assert obj.socket.getsockname() == path |
27 | 61 |
|
28 | 62 | listener = listeners[0]
|
29 | 63 | # Make sure the backlog is at least 2
|
30 |
| - c1 = await open_stream_to_socket_listener(listener) |
31 |
| - c2 = await open_stream_to_socket_listener(listener) |
| 64 | + c1 = await open_stream_to_unix_socket_listener(listener, path) |
| 65 | + c2 = await open_stream_to_unix_socket_listener(listener, path) |
32 | 66 |
|
33 | 67 | s1 = await listener.accept()
|
34 | 68 | s2 = await listener.accept()
|
|
0 commit comments