Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to synapse.crypto. #11146

Merged
merged 3 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11146.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.crypto`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ files =
[mypy-synapse.api.*]
disallow_untyped_defs = True

[mypy-synapse.crypto.*]
disallow_untyped_defs = True

[mypy-synapse.events.*]
disallow_untyped_defs = True

Expand Down
40 changes: 26 additions & 14 deletions synapse/crypto/context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
TLSVersion,
platformTrust,
)
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.python.failure import Failure
from twisted.web.iweb import IPolicyForHTTPS

from synapse.config.homeserver import HomeServerConfig

logger = logging.getLogger(__name__)


Expand All @@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory):
per https://github.com/matrix-org/synapse/issues/1691
"""

def __init__(self, config):
def __init__(self, config: HomeServerConfig):
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
# switch to those (see https://github.com/pyca/cryptography/issues/5379).
#
Expand All @@ -64,7 +67,7 @@ def __init__(self, config):
self.configure_context(self._context, config)

@staticmethod
def configure_context(context, config):
def configure_context(context: SSL.Context, config: HomeServerConfig) -> None:
try:
_ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
context.set_tmp_ecdh(_ecCurve)
Expand All @@ -75,14 +78,15 @@ def configure_context(context, config):
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
)
context.use_certificate_chain_file(config.tls.tls_certificate_file)
assert config.tls.tls_private_key is not None
context.use_privatekey(config.tls.tls_private_key)

# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
)

def getContext(self):
def getContext(self) -> SSL.Context:
return self._context


Expand All @@ -98,7 +102,7 @@ class FederationPolicyForHTTPS:
constructs an SSLClientConnectionCreator factory accordingly.
"""

def __init__(self, config):
def __init__(self, config: HomeServerConfig):
self._config = config

# Check if we're using a custom list of a CA certificates
Expand Down Expand Up @@ -131,7 +135,7 @@ def __init__(self, config):
self._config.tls.federation_certificate_verification_whitelist
)

def get_options(self, host: bytes):
def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator:
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
# against the str whitelist. The hostnames in the whitelist are already
# IDNA-encoded like the hosts will be here.
Expand All @@ -153,7 +157,9 @@ def get_options(self, host: bytes):

return SSLClientConnectionCreator(host, ssl_context, should_verify)

def creatorForNetloc(self, hostname, port):
def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
"""Implements the IPolicyForHTTPS interface so that this can be passed
directly to agents.
"""
Expand All @@ -169,16 +175,18 @@ class RegularPolicyForHTTPS:
trust root.
"""

def __init__(self):
def __init__(self) -> None:
trust_root = platformTrust()
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
self._ssl_context.set_info_callback(_context_info_cb)

def creatorForNetloc(self, hostname, port):
def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
return SSLClientConnectionCreator(hostname, self._ssl_context, True)


def _context_info_cb(ssl_connection, where, ret):
def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None:
"""The 'information callback' for our openssl context objects.

Note: Once this is set as the info callback on a Context object, the Context should
Expand All @@ -204,11 +212,13 @@ class SSLClientConnectionCreator:
Replaces twisted.internet.ssl.ClientTLSOptions
"""

def __init__(self, hostname: bytes, ctx, verify_certs: bool):
def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool):
self._ctx = ctx
self._verifier = ConnectionVerifier(hostname, verify_certs)

def clientConnectionForTLS(self, tls_protocol):
def clientConnectionForTLS(
self, tls_protocol: TLSMemoryBIOProtocol
) -> SSL.Connection:
context = self._ctx
connection = SSL.Connection(context, None)

Expand All @@ -219,7 +229,7 @@ def clientConnectionForTLS(self, tls_protocol):
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
# tls_protocol so that the SSL context's info callback has something to
# call to do the cert verification.
tls_protocol._synapse_tls_verifier = self._verifier
tls_protocol._synapse_tls_verifier = self._verifier # type: ignore[attr-defined]
return connection


Expand All @@ -244,7 +254,9 @@ def __init__(self, hostname: bytes, verify_certs: bool):
self._hostnameBytes = hostname
self._hostnameASCII = self._hostnameBytes.decode("ascii")

def verify_context_info_cb(self, ssl_connection, where):
def verify_context_info_cb(
self, ssl_connection: SSL.Connection, where: int
) -> None:
if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
ssl_connection.set_tlsext_host_name(self._hostnameBytes)

Expand Down
2 changes: 1 addition & 1 deletion synapse/crypto/event_signing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def compute_content_hash(


def compute_event_reference_hash(
event, hash_algorithm: Hasher = hashlib.sha256
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted
event.
Expand Down
8 changes: 5 additions & 3 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def from_json_object(
server_name: str,
json_object: JsonDict,
minimum_valid_until_ms: int,
):
) -> "VerifyJsonRequest":
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server.
"""
Expand All @@ -104,7 +104,7 @@ def from_event(
server_name: str,
event: EventBase,
minimum_valid_until_ms: int,
):
) -> "VerifyJsonRequest":
"""Create a VerifyJsonRequest to verify all signatures on an event
object for the given server.
"""
Expand Down Expand Up @@ -449,7 +449,9 @@ def __init__(self, hs: "HomeServer"):

self.store = hs.get_datastore()

async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
key_ids_to_fetch = (
(queue_value.server_name, key_id)
for queue_value in keys_to_fetch
Expand Down