Skip to content

Commit a82874e

Browse files
committed
Override get_connection_with_tls_context instead
This PR replaces the deprecated `get_connection` method in `MatrixConnectionAdapter` with a similar implementation in `get_connection_with_tls_context`. It also fixes a couple mypy errors that were firing on the arguments of `get_connection`.
1 parent 30e9f6e commit a82874e

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

scripts-dev/federation_client.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import base64
4444
import json
4545
import sys
46-
from typing import Any, Dict, Optional, Tuple
46+
from typing import Any, Dict, Mapping, Optional, Tuple, Union
4747
from urllib import parse as urlparse
4848

4949
import requests
@@ -75,7 +75,7 @@ def encode_canonical_json(value: object) -> bytes:
7575
value,
7676
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
7777
ensure_ascii=False,
78-
# Remove unecessary white space.
78+
# Remove unnecessary white space.
7979
separators=(",", ":"),
8080
# Sort the keys of dictionaries.
8181
sort_keys=True,
@@ -298,12 +298,41 @@ def send(
298298

299299
return super().send(request, *args, **kwargs)
300300

301-
def get_connection(
302-
self, url: str, proxies: Optional[Dict[str, str]] = None
301+
# def get_connection(
302+
# self, url: str, proxies: Optional[Dict[str, str]] = None,
303+
# ) -> HTTPConnectionPool:
304+
# # overrides the get_connection() method in the base class
305+
# parsed = urlparse.urlsplit(url)
306+
# (host, port, ssl_server_name) = self._lookup(parsed.netloc)
307+
# print(
308+
# f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
309+
# )
310+
# return self.poolmanager.connection_from_host(
311+
# host,
312+
# port=port,
313+
# scheme="https",
314+
# pool_kwargs={"server_hostname": ssl_server_name},
315+
# )
316+
317+
def get_connection_with_tls_context(
318+
self,
319+
request: PreparedRequest,
320+
verify: Optional[Union[bool, str]],
321+
proxies: Optional[Mapping[str, str]] = None,
322+
cert: Optional[Union[Tuple[str, str], str]] = None,
303323
) -> HTTPConnectionPool:
324+
# overrides the get_connection_with_tls_context() method in the base class
325+
# return self.get_connection(request.url, proxies)
304326
# overrides the get_connection() method in the base class
305-
parsed = urlparse.urlsplit(url)
306-
(host, port, ssl_server_name) = self._lookup(parsed.netloc)
327+
parsed = urlparse.urlsplit(request.url)
328+
329+
# Extract the hostname from the request URL and ensure it's a str.
330+
hostname = parsed.netloc
331+
if isinstance(hostname, bytes):
332+
hostname = hostname.decode("utf-8")
333+
assert isinstance(hostname, str)
334+
335+
(host, port, ssl_server_name) = self._lookup(hostname)
307336
print(
308337
f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
309338
)

0 commit comments

Comments
 (0)