Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fixes to federation_client dev script #14479

Merged
merged 5 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14479.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`scripts-dev/federation_client`: Fix routing on servers with `.well-known` files.
122 changes: 87 additions & 35 deletions scripts-dev/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@
import signedjson.types
import srvlookup
import yaml
from requests import PreparedRequest, Response
from requests.adapters import HTTPAdapter
from urllib3 import HTTPConnectionPool

# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# from http.client import HTTPConnection
# HTTPConnection.debuglevel = 1


Expand Down Expand Up @@ -103,6 +104,7 @@ def request(
destination: str,
path: str,
content: Optional[str],
verify_tls: bool,
) -> requests.Response:
if method is None:
if content is None:
Expand Down Expand Up @@ -141,7 +143,6 @@ def request(
s.mount("matrix://", MatrixConnectionAdapter())

headers: Dict[str, str] = {
"Host": destination,
"Authorization": authorization_headers[0],
}

Expand All @@ -152,7 +153,7 @@ def request(
method=method,
url=dest,
headers=headers,
verify=False,
verify=verify_tls,
data=content,
stream=True,
)
Expand Down Expand Up @@ -202,6 +203,12 @@ def main() -> None:

parser.add_argument("--body", help="Data to send as the body of the HTTP request")

parser.add_argument(
"--insecure",
action="store_true",
help="Disable TLS certificate verification",
)

parser.add_argument(
"path", help="request path, including the '/_matrix/federation/...' prefix."
)
Expand All @@ -227,6 +234,7 @@ def main() -> None:
args.destination,
args.path,
content=args.body,
verify_tls=not args.insecure,
)

sys.stderr.write("Status Code: %d\n" % (result.status_code,))
Expand Down Expand Up @@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:


class MatrixConnectionAdapter(HTTPAdapter):
def send(
self,
request: PreparedRequest,
*args: Any,
**kwargs: Any,
) -> Response:
# overrides the send() method in the base class.

# We need to look for .well-known redirects before passing the request up to
# HTTPAdapter.send().
assert isinstance(request.url, str)
parsed = urlparse.urlsplit(request.url)
server_name = parsed.netloc
well_known = self._get_well_known(parsed.netloc)

if well_known:
server_name = well_known

# replace the scheme in the uri with https, so that cert verification is done
# also replace the hostname if we got a .well-known result
request.url = urlparse.urlunsplit(
("https", server_name, parsed.path, parsed.query, parsed.fragment)
)

# at this point we also add the host header (otherwise urllib will add one
# based on the `host` from the connection returned by `get_connection`,
# which will be wrong if there is an SRV record).
request.headers["Host"] = server_name

return super().send(request, *args, **kwargs)

def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
# overrides the get_connection() method in the base class
parsed = urlparse.urlsplit(url)
(host, port, ssl_server_name) = self._lookup(parsed.netloc)
print(
f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
)
return self.poolmanager.connection_from_host(
host,
port=port,
scheme="https",
pool_kwargs={"server_hostname": ssl_server_name},
)

@staticmethod
def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
if s[-1] == "]":
def _lookup(server_name: str) -> Tuple[str, int, str]:
"""
Do an SRV lookup on a server name and return the host:port to connect to
Given the server_name (after any .well-known lookup), return the host, port and
the ssl server name
"""
if server_name[-1] == "]":
# ipv6 literal (with no port)
return s, 8448
return server_name, 8448, server_name

if ":" in s:
out = s.rsplit(":", 1)
if ":" in server_name:
# explicit port
out = server_name.rsplit(":", 1)
try:
port = int(out[1])
except ValueError:
raise ValueError("Invalid host:port '%s'" % s)
return out[0], port

# try a .well-known lookup
if not skip_well_known:
well_known = MatrixConnectionAdapter.get_well_known(s)
if well_known:
return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
raise ValueError("Invalid host:port '%s'" % (server_name,))
return out[0], port, out[0]

try:
srv = srvlookup.lookup("matrix", "tcp", s)[0]
return srv.host, srv.port
srv = srvlookup.lookup("matrix", "tcp", server_name)[0]
print(
f"SRV lookup on _matrix._tcp.{server_name} gave {srv}",
file=sys.stderr,
)
return srv.host, srv.port, server_name
except Exception:
return s, 8448
return server_name, 8448, server_name

@staticmethod
def get_well_known(server_name: str) -> Optional[str]:
uri = "https://%s/.well-known/matrix/server" % (server_name,)
print("fetching %s" % (uri,), file=sys.stderr)
def _get_well_known(server_name: str) -> Optional[str]:
if ":" in server_name:
# explicit port, or ipv6 literal. Either way, no .well-known
return None

# TODO: check for ipv4 literals

uri = f"https://{server_name}/.well-known/matrix/server"
print(f"fetching {uri}", file=sys.stderr)

try:
resp = requests.get(uri)
Expand All @@ -304,19 +369,6 @@ def get_well_known(server_name: str) -> Optional[str]:
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None

def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
parsed = urlparse.urlparse(url)

(host, port) = self.lookup(parsed.netloc)
netloc = "%s:%d" % (host, port)
print("Connecting to %s" % (netloc,), file=sys.stderr)
url = urlparse.urlunparse(
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
)
return super().get_connection(url, proxies)


if __name__ == "__main__":
main()