Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fixes to federation_client dev script (#14479)
Browse files Browse the repository at this point in the history
* Attempt to fix federation-client devscript handling of .well-known

The script was setting the wrong value in the Host header

* Fix TLS verification

Turns out that actually doing TLS verification isn't that hard. Let's enable
it.
  • Loading branch information
richvdh authored Nov 20, 2022
1 parent e1b15f2 commit 8d133a8
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 35 deletions.
1 change: 1 addition & 0 deletions changelog.d/14479.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`scripts-dev/federation_client`: Fix routing on servers with `.well-known` files.
122 changes: 87 additions & 35 deletions scripts-dev/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@
import signedjson.types
import srvlookup
import yaml
from requests import PreparedRequest, Response
from requests.adapters import HTTPAdapter
from urllib3 import HTTPConnectionPool

# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# from http.client import HTTPConnection
# HTTPConnection.debuglevel = 1


Expand Down Expand Up @@ -103,6 +104,7 @@ def request(
destination: str,
path: str,
content: Optional[str],
verify_tls: bool,
) -> requests.Response:
if method is None:
if content is None:
Expand Down Expand Up @@ -141,7 +143,6 @@ def request(
s.mount("matrix://", MatrixConnectionAdapter())

headers: Dict[str, str] = {
"Host": destination,
"Authorization": authorization_headers[0],
}

Expand All @@ -152,7 +153,7 @@ def request(
method=method,
url=dest,
headers=headers,
verify=False,
verify=verify_tls,
data=content,
stream=True,
)
Expand Down Expand Up @@ -202,6 +203,12 @@ def main() -> None:

parser.add_argument("--body", help="Data to send as the body of the HTTP request")

parser.add_argument(
"--insecure",
action="store_true",
help="Disable TLS certificate verification",
)

parser.add_argument(
"path", help="request path, including the '/_matrix/federation/...' prefix."
)
Expand All @@ -227,6 +234,7 @@ def main() -> None:
args.destination,
args.path,
content=args.body,
verify_tls=not args.insecure,
)

sys.stderr.write("Status Code: %d\n" % (result.status_code,))
Expand Down Expand Up @@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:


class MatrixConnectionAdapter(HTTPAdapter):
def send(
self,
request: PreparedRequest,
*args: Any,
**kwargs: Any,
) -> Response:
# overrides the send() method in the base class.

# We need to look for .well-known redirects before passing the request up to
# HTTPAdapter.send().
assert isinstance(request.url, str)
parsed = urlparse.urlsplit(request.url)
server_name = parsed.netloc
well_known = self._get_well_known(parsed.netloc)

if well_known:
server_name = well_known

# replace the scheme in the uri with https, so that cert verification is done
# also replace the hostname if we got a .well-known result
request.url = urlparse.urlunsplit(
("https", server_name, parsed.path, parsed.query, parsed.fragment)
)

# at this point we also add the host header (otherwise urllib will add one
# based on the `host` from the connection returned by `get_connection`,
# which will be wrong if there is an SRV record).
request.headers["Host"] = server_name

return super().send(request, *args, **kwargs)

def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
# overrides the get_connection() method in the base class
parsed = urlparse.urlsplit(url)
(host, port, ssl_server_name) = self._lookup(parsed.netloc)
print(
f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
)
return self.poolmanager.connection_from_host(
host,
port=port,
scheme="https",
pool_kwargs={"server_hostname": ssl_server_name},
)

@staticmethod
def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
if s[-1] == "]":
def _lookup(server_name: str) -> Tuple[str, int, str]:
"""
Do an SRV lookup on a server name and return the host:port to connect to
Given the server_name (after any .well-known lookup), return the host, port and
the ssl server name
"""
if server_name[-1] == "]":
# ipv6 literal (with no port)
return s, 8448
return server_name, 8448, server_name

if ":" in s:
out = s.rsplit(":", 1)
if ":" in server_name:
# explicit port
out = server_name.rsplit(":", 1)
try:
port = int(out[1])
except ValueError:
raise ValueError("Invalid host:port '%s'" % s)
return out[0], port

# try a .well-known lookup
if not skip_well_known:
well_known = MatrixConnectionAdapter.get_well_known(s)
if well_known:
return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
raise ValueError("Invalid host:port '%s'" % (server_name,))
return out[0], port, out[0]

try:
srv = srvlookup.lookup("matrix", "tcp", s)[0]
return srv.host, srv.port
srv = srvlookup.lookup("matrix", "tcp", server_name)[0]
print(
f"SRV lookup on _matrix._tcp.{server_name} gave {srv}",
file=sys.stderr,
)
return srv.host, srv.port, server_name
except Exception:
return s, 8448
return server_name, 8448, server_name

@staticmethod
def get_well_known(server_name: str) -> Optional[str]:
uri = "https://%s/.well-known/matrix/server" % (server_name,)
print("fetching %s" % (uri,), file=sys.stderr)
def _get_well_known(server_name: str) -> Optional[str]:
if ":" in server_name:
# explicit port, or ipv6 literal. Either way, no .well-known
return None

# TODO: check for ipv4 literals

uri = f"https://{server_name}/.well-known/matrix/server"
print(f"fetching {uri}", file=sys.stderr)

try:
resp = requests.get(uri)
Expand All @@ -304,19 +369,6 @@ def get_well_known(server_name: str) -> Optional[str]:
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None

def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
parsed = urlparse.urlparse(url)

(host, port) = self.lookup(parsed.netloc)
netloc = "%s:%d" % (host, port)
print("Connecting to %s" % (netloc,), file=sys.stderr)
url = urlparse.urlunparse(
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
)
return super().get_connection(url, proxies)


if __name__ == "__main__":
main()

0 comments on commit 8d133a8

Please sign in to comment.