From 0fd3068a4715a4f67f4c5e2c57cb085bb94fd015 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 24 Mar 2021 20:52:11 +0100 Subject: [PATCH 1/9] add B006 and B008 fixes --- contrib/cmdclient/console.py | 4 ++- contrib/cmdclient/http.py | 15 +++++++---- setup.cfg | 4 +-- synapse/appservice/scheduler.py | 4 +-- synapse/config/ratelimiting.py | 4 ++- synapse/events/__init__.py | 14 +++++++--- synapse/federation/units.py | 4 +-- synapse/handlers/appservice.py | 4 +-- synapse/handlers/federation.py | 4 +-- synapse/handlers/message.py | 11 +++++--- synapse/handlers/register.py | 4 ++- synapse/handlers/sync.py | 8 +++--- synapse/http/client.py | 4 +-- synapse/http/proxyagent.py | 6 +++-- synapse/logging/opentracing.py | 3 ++- synapse/module_api/__init__.py | 10 +++---- synapse/notifier.py | 19 ++++++++------ synapse/storage/database.py | 19 +++++++++----- synapse/storage/databases/main/events.py | 7 +++-- .../storage/databases/main/group_server.py | 4 ++- synapse/storage/databases/main/state.py | 6 +++-- synapse/storage/databases/state/bg_updates.py | 3 ++- synapse/storage/databases/state/store.py | 3 ++- synapse/storage/state.py | 26 ++++++++++++------- synapse/storage/util/id_generators.py | 11 ++++++-- synapse/util/caches/lrucache.py | 14 +++++----- .../test_matrix_federation_agent.py | 8 +++--- tests/replication/_base.py | 4 +-- .../replication/slave/storage/test_events.py | 16 +++++++----- tests/rest/client/v1/test_rooms.py | 5 +++- tests/rest/client/v1/utils.py | 14 +++++++--- tests/rest/client/v2_alpha/test_relations.py | 4 +-- tests/storage/test_id_generators.py | 14 +++++----- tests/storage/test_redaction.py | 4 +-- tests/test_state.py | 5 ++-- tests/test_visibility.py | 6 +++-- tests/util/test_ratelimitutils.py | 4 +-- 37 files changed, 187 insertions(+), 112 deletions(-) diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 67e032244ecc..243b658425f1 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -718,7 +718,7 @@ def _run_and_pprint( method, path, data=None, - query_params={"access_token": None}, + query_params: dict = None, alt_text=None, ): """Runs an HTTP request and pretty prints the output. @@ -729,6 +729,8 @@ def _run_and_pprint( data: Raw JSON data if any query_params: dict of query parameters to add to the url """ + query_params = query_params or {"access_token": None} + url = self._url() + path if "access_token" in query_params: query_params["access_token"] = self._tok() diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 851e80c25bb4..a381af7cc92f 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -85,8 +85,9 @@ def get_json(self, url, args=None): body = yield readBody(response) defer.returnValue(json.loads(body)) - def _create_put_request(self, url, json_data, headers_dict={}): + def _create_put_request(self, url, json_data, headers_dict: dict = None): """Wrapper of _create_request to issue a PUT request""" + headers_dict = headers_dict or {} if "Content-Type" not in headers_dict: raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) @@ -95,14 +96,16 @@ def _create_put_request(self, url, json_data, headers_dict={}): "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) - def _create_get_request(self, url, headers_dict={}): + def _create_get_request(self, url, headers_dict: dict = None): """Wrapper of _create_request to issue a GET request""" - return self._create_request("GET", url, headers_dict=headers_dict) + return self._create_request("GET", url, headers_dict=headers_dict or {}) @defer.inlineCallbacks def do_request( - self, method, url, data=None, qparams=None, jsonreq=True, headers={} + self, method, url, data=None, qparams=None, jsonreq=True, headers: dict = None ): + headers = headers or {} + if qparams: url = "%s?%s" % (url, urllib.urlencode(qparams, True)) @@ -123,8 +126,10 @@ def do_request( defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def _create_request(self, method, url, producer=None, headers_dict={}): + def _create_request(self, method, url, producer=None, headers_dict: dict = None): """Creates and sends a request to the given url""" + headers_dict = headers_dict or {} + headers_dict["User-Agent"] = ["Synapse Cmd Client"] retries_left = 5 diff --git a/setup.cfg b/setup.cfg index 7329eed213d7..5fdb51ac7397 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,8 +18,8 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -# B00*: Subsection of the bugbear suite (TODO: add in remaining fixes) -ignore=W503,W504,E203,E731,E501,B006,B007,B008 +# B007: Subsection of the bugbear suite (TODO: add in remaining fixes) +ignore=W503,W504,E203,E731,E501,B007 [isort] line_length = 88 diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 366c476f807a..f49ea611b417 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -191,11 +191,11 @@ async def send( self, service: ApplicationService, events: List[EventBase], - ephemeral: List[JsonDict] = [], + ephemeral: List[JsonDict] = None, ): try: txn = await self.store.create_appservice_txn( - service=service, events=events, ephemeral=ephemeral + service=service, events=events, ephemeral=ephemeral or [] ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 3f3997f4e53b..80adf5c42080 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -21,8 +21,10 @@ class RateLimitConfig: def __init__( self, config: Dict[str, float], - defaults={"per_second": 0.17, "burst_count": 3.0}, + defaults: Dict[str, float] = None, ): + defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} + self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = int(config.get("burst_count", defaults["burst_count"])) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8f6b955d17b7..6abcab8374e2 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -330,9 +330,11 @@ def __init__( self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: JsonDict = None, rejected_reason: Optional[str] = None, ): + internal_metadata_dict = internal_metadata_dict or {} + event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -386,9 +388,11 @@ def __init__( self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: JsonDict = None, rejected_reason: Optional[str] = None, ): + internal_metadata_dict = internal_metadata_dict or {} + event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -507,9 +511,11 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]: def make_event_from_dict( event_dict: JsonDict, room_version: RoomVersion = RoomVersions.V1, - internal_metadata_dict: JsonDict = {}, + internal_metadata_dict: JsonDict = None, rejected_reason: Optional[str] = None, ) -> EventBase: """Construct an EventBase from the given event dict""" event_type = _event_type_from_format_version(room_version.event_format) - return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason) + return event_type( + event_dict, room_version, internal_metadata_dict or {}, rejected_reason + ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b662c4262120..2e96494c678c 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -98,7 +98,7 @@ class Transaction(JsonEncodedObject): "pdus", ] - def __init__(self, transaction_id=None, pdus=[], **kwargs): + def __init__(self, transaction_id=None, pdus: list = None, **kwargs): """If we include a list of pdus then we decode then as PDU's automatically. """ @@ -107,7 +107,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) + super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 996f9e5debc8..95413956b661 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -182,7 +182,7 @@ def notify_interested_services_ephemeral( self, stream_key: str, new_token: Optional[int], - users: Collection[Union[str, UserID]] = [], + users: Collection[Union[str, UserID]] = None, ): """This is called by the notifier in the background when a ephemeral event handled by the homeserver. @@ -215,7 +215,7 @@ def notify_interested_services_ephemeral( # We only start a new background process if necessary rather than # optimistically (to cut down on overhead). self._notify_interested_services_ephemeral( - services, stream_key, new_token, users + services, stream_key, new_token, users or [] ) @wrap_as_background_process("notify_interested_services_ephemeral") diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74cf4..250c5250edb5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1772,7 +1772,7 @@ async def _make_and_verify_event( room_id: str, user_id: str, membership: str, - content: JsonDict = {}, + content: JsonDict = None, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( @@ -1780,7 +1780,7 @@ async def _make_and_verify_event( event, room_version, ) = await self.federation_client.make_membership_event( - target_hosts, room_id, user_id, membership, content, params=params + target_hosts, room_id, user_id, membership, content or {}, params=params ) logger.debug("Got response to make_%s: %s", membership, event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1b7c065b34e1..e7946f3061a4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -137,7 +137,7 @@ async def get_state_events( self, user_id: str, room_id: str, - state_filter: StateFilter = StateFilter.all(), + state_filter: StateFilter = None, at_token: Optional[StreamToken] = None, is_guest: bool = False, ) -> List[dict]: @@ -186,7 +186,7 @@ async def get_state_events( event = last_events[0] if visible_events: room_state = await self.state_store.get_state_for_events( - [event.event_id], state_filter=state_filter + [event.event_id], state_filter=state_filter or StateFilter.all() ) room_state = room_state[event.event_id] else: @@ -874,7 +874,7 @@ async def handle_new_client_event( event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = [], + extra_users: List[UserID] = None, ignore_shadow_ban: bool = False, ) -> EventBase: """Processes a new event. @@ -902,6 +902,7 @@ async def handle_new_client_event( Raises: ShadowBanError if the requester has been shadow-banned. """ + extra_users = extra_users or [] # we don't apply shadow-banning to membership events here. Invites are blocked # higher up the stack, and we allow shadow-banned users to send join and leave @@ -1071,7 +1072,7 @@ async def persist_and_notify_client_event( event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = [], + extra_users: List[UserID] = None, ) -> EventBase: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1083,6 +1084,8 @@ async def persist_and_notify_client_event( it was de-duplicated (e.g. because we had already persisted an event with the same transaction ID.) """ + extra_users = extra_users or [] + assert self.storage.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d520..d1281d9b28f8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -169,7 +169,7 @@ async def register_user( user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: Iterable[str] = [], + bind_emails: Iterable[str] = None, by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, @@ -204,6 +204,8 @@ async def register_user( Raises: SynapseError if there was a problem registering. """ + bind_emails = bind_emails or [] + self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ee607e6e6576..bf1c4aba12e2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -540,7 +540,7 @@ async def _load_filtered_recents( ) async def get_state_after_event( - self, event: EventBase, state_filter: StateFilter = StateFilter.all() + self, event: EventBase, state_filter: StateFilter = None ) -> StateMap[str]: """ Get the room state after the given event @@ -550,7 +550,7 @@ async def get_state_after_event( state_filter: The state filter used to fetch state from the database. """ state_ids = await self.state_store.get_state_ids_for_event( - event.event_id, state_filter=state_filter + event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): state_ids = dict(state_ids) @@ -561,7 +561,7 @@ async def get_state_at( self, room_id: str, stream_position: StreamToken, - state_filter: StateFilter = StateFilter.all(), + state_filter: StateFilter = None, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -581,7 +581,7 @@ async def get_state_at( if last_events: last_event = last_events[-1] state = await self.get_state_after_event( - last_event, state_filter=state_filter + last_event, state_filter=state_filter or StateFilter.all() ) else: diff --git a/synapse/http/client.py b/synapse/http/client.py index a0caba84e4ba..5ce21ad458ce 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -297,7 +297,7 @@ class SimpleHttpClient: def __init__( self, hs: "HomeServer", - treq_args: Dict[str, Any] = {}, + treq_args: Dict[str, Any] = None, ip_whitelist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None, use_proxy: bool = False, @@ -317,7 +317,7 @@ def __init__( self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - self._extra_treq_args = treq_args + self._extra_treq_args = treq_args or {} self.user_agent = hs.version_string self.clock = hs.get_clock() diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 16ec850064dd..3d2aee933641 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -27,7 +27,7 @@ from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IPolicyForHTTPS from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint @@ -88,12 +88,14 @@ def __init__( self, reactor, proxy_reactor=None, - contextFactory=BrowserLikePolicyForHTTPS(), + contextFactory: IPolicyForHTTPS = None, connectTimeout=None, bindAddress=None, pool=None, use_proxy=False, ): + contextFactory = contextFactory or BrowserLikePolicyForHTTPS() + _AgentBase.__init__(self, reactor, pool) if proxy_reactor is None: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 10bd4a14614b..f6a8c420b554 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -478,7 +478,7 @@ def start_active_span_from_request( def start_active_span_from_edu( edu_content, operation_name, - references=[], + references: list = None, tags=None, start_time=None, ignore_active_span=False, @@ -493,6 +493,7 @@ def start_active_span_from_edu( For the other args see opentracing.tracer """ + references = references or [] if opentracing is None: return noop_context_manager() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 781e02fbbb31..2cc062cdf61d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple from twisted.internet import defer @@ -118,7 +118,7 @@ def check_user_exists(self, user_id): return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register(self, localpart, displayname=None, emails=[]): + def register(self, localpart, displayname=None, emails: List[str] = None): """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -138,11 +138,11 @@ def register(self, localpart, displayname=None, emails=[]): logger.warning( "Using deprecated ModuleApi.register which creates a dummy user device." ) - user_id = yield self.register_user(localpart, displayname, emails) + user_id = yield self.register_user(localpart, displayname, emails or []) _, access_token = yield self.register_device(user_id) return user_id, access_token - def register_user(self, localpart, displayname=None, emails=[]): + def register_user(self, localpart, displayname=None, emails: List[str] = None): """Registers a new user with given localpart and optional displayname, emails. Args: @@ -161,7 +161,7 @@ def register_user(self, localpart, displayname=None, emails=[]): self._hs.get_registration_handler().register_user( localpart=localpart, default_display_name=displayname, - bind_emails=emails, + bind_emails=emails or [], ) ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 1374aae49051..d7b9a8043c83 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -266,7 +266,7 @@ def on_new_room_event( event: EventBase, event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = [], + extra_users: Collection[UserID] = None, ): """Unwraps event and calls `on_new_room_event_args`.""" self.on_new_room_event_args( @@ -276,7 +276,7 @@ def on_new_room_event( state_key=event.get("state_key"), membership=event.content.get("membership"), max_room_stream_token=max_room_stream_token, - extra_users=extra_users, + extra_users=extra_users or [], ) def on_new_room_event_args( @@ -287,7 +287,7 @@ def on_new_room_event_args( membership: Optional[str], event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = [], + extra_users: Collection[UserID] = None, ): """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -303,7 +303,7 @@ def on_new_room_event_args( self.pending_new_room_events.append( _PendingRoomEventEntry( event_pos=event_pos, - extra_users=extra_users, + extra_users=extra_users or [], room_id=room_id, type=event_type, state_key=state_key, @@ -372,14 +372,14 @@ def _notify_app_services_ephemeral( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = [], + users: Collection[Union[str, UserID]] = None, ): try: stream_token = None if isinstance(new_token, int): stream_token = new_token self.appservice_handler.notify_interested_services_ephemeral( - stream_key, stream_token, users + stream_key, stream_token, users or [] ) except Exception: logger.exception("Error notifying application services of event") @@ -394,13 +394,16 @@ def on_new_event( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = [], - rooms: Collection[str] = [], + users: Collection[Union[str, UserID]] = None, + rooms: Collection[str] = None, ): """Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ + users = users or [] + rooms = rooms or [] + with Measure(self.clock, "on_new_event"): user_streams = set() diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 5b0b9a20bf86..cf42709d4482 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -900,7 +900,7 @@ async def simple_upsert( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Dict[str, Any] = None, desc: str = "simple_upsert", lock: bool = True, ) -> Optional[bool]: @@ -927,6 +927,8 @@ async def simple_upsert( Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + attempts = 0 while True: try: @@ -964,7 +966,7 @@ def simple_upsert_txn( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Dict[str, Any] = None, lock: bool = True, ) -> Optional[bool]: """ @@ -982,6 +984,8 @@ def simple_upsert_txn( Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values @@ -1003,7 +1007,7 @@ def simple_upsert_txn_emulated( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Dict[str, Any] = None, lock: bool = True, ) -> bool: """ @@ -1017,6 +1021,8 @@ def simple_upsert_txn_emulated( Returns True if a new entry was created, False if an existing one was updated. """ + insertion_values = insertion_values or {} + # We need to lock the table :(, unless we're *really* careful if lock: self.engine.lock_table(txn, table) @@ -1077,7 +1083,7 @@ def simple_upsert_txn_native_upsert( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = {}, + insertion_values: Dict[str, Any] = None, ) -> None: """ Use the native UPSERT functionality in recent PostgreSQL versions. @@ -1090,7 +1096,7 @@ def simple_upsert_txn_native_upsert( """ allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) - allvalues.update(insertion_values) + allvalues.update(insertion_values or {}) if not values: latter = "NOTHING" @@ -1513,7 +1519,7 @@ async def simple_select_many_batch( column: str, iterable: Iterable[Any], retcols: Iterable[str], - keyvalues: Dict[str, Any] = {}, + keyvalues: Dict[str, Any] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, ) -> List[Any]: @@ -1531,6 +1537,7 @@ async def simple_select_many_batch( desc: description of the transaction, for logging and metrics batch_size: the number of rows for each select query """ + keyvalues = keyvalues or {} results = [] # type: List[Dict[str, Any]] if not iterable: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 98dac19a9525..053daf8b6ae2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -320,8 +320,8 @@ def _persist_events_txn( txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - state_delta_for_room: Dict[str, DeltaState] = {}, - new_forward_extremeties: Dict[str, List[str]] = {}, + state_delta_for_room: Dict[str, DeltaState] = None, + new_forward_extremeties: Dict[str, List[str]] = None, ): """Insert some number of room events into the necessary database tables. @@ -342,6 +342,9 @@ def _persist_events_txn( extremities. """ + state_delta_for_room = state_delta_for_room or {} + new_forward_extremeties = new_forward_extremeties or {} + all_events_and_contexts = events_and_contexts min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ac07e0197b88..47f02745abe5 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1171,7 +1171,7 @@ async def register_user_group_membership( user_id: str, membership: str, is_admin: bool = False, - content: JsonDict = {}, + content: JsonDict = None, local_attestation: Optional[dict] = None, remote_attestation: Optional[dict] = None, is_publicised: bool = False, @@ -1192,6 +1192,8 @@ async def register_user_group_membership( is_publicised: Whether this should be publicised. """ + content = content or {} + def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a7f371732fd7..6b6dd4135faa 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -190,7 +190,7 @@ def _get_current_state_ids_txn(txn): # FIXME: how should this be cached? async def get_filtered_current_state_ids( - self, room_id: str, state_filter: StateFilter = StateFilter.all() + self, room_id: str, state_filter: StateFilter = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result @@ -205,7 +205,9 @@ async def get_filtered_current_state_ids( Map from type/state_key to event ID. """ - where_clause, where_args = state_filter.make_sql_filter_clause() + where_clause, where_args = ( + state_filter or StateFilter.all() + ).make_sql_filter_clause() if not where_clause: # We delegate to the cached version diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 1fd333b707e1..df86b042c7b2 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -73,8 +73,9 @@ def _count_state_group_hops_txn(self, txn, state_group): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() + self, txn, groups, state_filter: StateFilter = None ): + state_filter = state_filter or StateFilter.all() results = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e2240703a784..3979c14a4378 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -209,7 +209,7 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter): return state_filter.filter_state(state_dict_ids), not missing_types async def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + self, groups: Iterable[int], state_filter: StateFilter = None ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -222,6 +222,7 @@ async def _get_state_for_groups( Returns: Dict of state group to state map. """ + state_filter = state_filter or StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 2e277a21c458..2987b148471c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -449,7 +449,7 @@ def _get_state_groups_from_groups( return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: StateFilter = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -465,7 +465,7 @@ async def get_state_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter + groups, state_filter or StateFilter.all() ) state_event_map = await self.stores.main.get_events( @@ -485,7 +485,7 @@ async def get_state_for_events( return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all() + self, event_ids: Iterable[str], state_filter: StateFilter = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -502,7 +502,7 @@ async def get_state_ids_for_events( groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter + groups, state_filter or StateFilter.all() ) event_to_state = { @@ -513,7 +513,7 @@ async def get_state_ids_for_events( return {event: event_to_state[event] for event in event_ids} async def get_state_for_event( - self, event_id: str, state_filter: StateFilter = StateFilter.all() + self, event_id: str, state_filter: StateFilter = None ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -525,11 +525,13 @@ async def get_state_for_event( Returns: A dict from (type, state_key) -> state_event """ - state_map = await self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) return state_map[event_id] async def get_state_ids_for_event( - self, event_id: str, state_filter: StateFilter = StateFilter.all() + self, event_id: str, state_filter: StateFilter = None ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -541,11 +543,13 @@ async def get_state_ids_for_event( Returns: A dict from (type, state_key) -> state_event """ - state_map = await self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) return state_map[event_id] def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + self, groups: Iterable[int], state_filter: StateFilter = None ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -558,7 +562,9 @@ def _get_state_for_groups( Returns: Dict of state group to state map. """ - return self.stores.state._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) async def store_state_group( self, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d4643c4fdf30..0d8653dda806 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -17,7 +17,7 @@ import threading from collections import OrderedDict from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import attr @@ -91,7 +91,14 @@ class StreamIdGenerator: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[], step=1): + def __init__( + self, + db_conn, + table, + column, + extra_tables: Iterable[Tuple] = frozenset(), + step=1, + ): assert step != 0 self._lock = threading.Lock() self._step = step diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 60bb6ff642f2..a609d25741b3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -57,12 +57,12 @@ def enumerate_leaves(node, depth): class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks=set()): + def __init__(self, prev_node, next_node, key, value, callbacks: set = None): self.prev_node = prev_node self.next_node = next_node self.key = key self.value = value - self.callbacks = callbacks + self.callbacks = callbacks or set() class LruCache(Generic[KT, VT]): @@ -176,10 +176,10 @@ def cache_len(): self.len = synchronized(cache_len) - def add_node(key, value, callbacks=set()): + def add_node(key, value, callbacks: set = None): prev_node = list_root next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value, callbacks) + node = _Node(prev_node, next_node, key, value, callbacks or set()) prev_node.next_node = node next_node.prev_node = node cache[key] = node @@ -237,7 +237,7 @@ def cache_get( def cache_get( key: KT, default: Optional[T] = None, - callbacks: Iterable[Callable[[], None]] = [], + callbacks: Iterable[Callable[[], None]] = frozenset(), update_metrics: bool = True, ): node = cache.get(key, None) @@ -253,7 +253,9 @@ def cache_get( return default @synchronized - def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): + def cache_set( + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = frozenset() + ): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4c56253da549..70d4e7e6af48 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -180,7 +180,7 @@ def _make_get_request(self, uri): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={} + self, client_factory, expected_sni, content, response_headers: dict = None ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -202,10 +202,10 @@ def _handle_well_known_connection( self.assertEqual( request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] ) - self._send_well_known_response(request, content, headers=response_headers) + self._send_well_known_response(request, content, headers=response_headers or {}) return well_known_server - def _send_well_known_response(self, request, content, headers={}): + def _send_well_known_response(self, request, content, headers: dict = None): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -213,7 +213,7 @@ def _send_well_known_response(self, request, content, headers={}): self.assertEqual(request.path, b"/.well-known/matrix/server") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response - for k, v in headers.items(): + for k, v in (headers or {}).items(): request.setHeader(k, v) request.write(content) request.finish() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 67b7913666fc..19b8c8f3b47c 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -266,7 +266,7 @@ def create_test_resource(self): return resource def make_worker_hs( - self, worker_app: str, extra_config: dict = {}, **kwargs + self, worker_app: str, extra_config: dict = None, **kwargs ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. @@ -283,7 +283,7 @@ def make_worker_hs( config = self._get_worker_hs_config() config["worker_app"] = worker_app - config.update(extra_config) + config.update(extra_config or {}) worker_hs = self.setup_test_homeserver( homeserver_to_use=GenericWorkerServer, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 0ceb0f935cd4..9c2b984330a0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Iterable from canonicaljson import encode_canonical_json @@ -332,15 +333,18 @@ def build_event( room_id=ROOM_ID, type="m.room.message", key=None, - internal={}, + internal: dict = None, depth=None, - prev_events=[], - auth_events=[], - prev_state=[], + prev_events: list = None, + auth_events: list = None, + prev_state: list = None, redacts=None, - push_actions=[], + push_actions: Iterable = frozenset(), **content ): + prev_events = prev_events or [] + auth_events = auth_events or [] + prev_state = prev_state or [] if depth is None: depth = self.event_id @@ -369,7 +373,7 @@ def build_event( if redacts is not None: event_dict["redacts"] = redacts - event = make_event_from_dict(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {}) self.event_id += 1 state_handler = self.hs.get_state_handler() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ed65f645fc2c..715414a3107d 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,6 +19,7 @@ """Tests REST events for /rooms paths.""" import json +from typing import Iterable from urllib import parse as urlparse from mock import Mock @@ -207,7 +208,9 @@ def test_topic_perms(self): ) self.assertEquals(403, channel.code, msg=channel.result["body"]) - def _test_get_membership(self, room=None, members=[], expect_code=None): + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 946740aa5d51..75d1512234ff 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -132,7 +132,7 @@ def change_membership( src: str, targ: str, membership: str, - extra_data: dict = {}, + extra_data: dict = None, tok: Optional[str] = None, expect_code: int = 200, ) -> None: @@ -156,7 +156,7 @@ def change_membership( path = path + "?access_token=%s" % tok data = {"membership": membership} - data.update(extra_data) + data.update(extra_data or {}) channel = make_request( self.hs.get_reactor(), @@ -187,7 +187,13 @@ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): ) def send_event( - self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + self, + room_id, + type, + content: dict = None, + txn_id=None, + tok=None, + expect_code=200, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -201,7 +207,7 @@ def send_event( self.site, "PUT", path, - json.dumps(content).encode("utf8"), + json.dumps(content or {}).encode("utf8"), ) assert ( diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index e7bb5583fc48..f78500d0f03d 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -681,7 +681,7 @@ def _send_relation( relation_type, event_type, key=None, - content={}, + content: dict = None, access_token=None, parent_id=None, ): @@ -713,7 +713,7 @@ def _send_relation( "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), - json.dumps(content).encode("utf-8"), + json.dumps(content or {}).encode("utf-8"), access_token=access_token, ) return channel diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index aad6bc907e43..12bbb867354f 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + from synapse.storage.database import DatabasePool from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -43,7 +45,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: List[str] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -53,7 +55,7 @@ def _create(conn): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) @@ -476,7 +478,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: List[str] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -486,7 +488,7 @@ def _create(conn): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], positive=False, ) @@ -612,7 +614,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: List[str] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -625,7 +627,7 @@ def _create(conn): ("foobar2", "instance_name", "stream_id"), ], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b2a0e6085678..390c3e26ec34 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -50,10 +50,10 @@ def prepare(self, reactor, clock, hs): self.depth = 1 def inject_room_member( - self, room, user, membership, replaces_state=None, extra_content={} + self, room, user, membership, replaces_state=None, extra_content: dict = None ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { diff --git a/tests/test_state.py b/tests/test_state.py index 6227a3ba9555..9436dd0f8ffd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from mock import Mock @@ -37,7 +38,7 @@ def create_event( state_key=None, depth=2, event_id=None, - prev_events=[], + prev_events: List[str] = None, **kwargs ): global _next_event_id @@ -58,7 +59,7 @@ def create_event( "sender": "@user_id:example.com", "room_id": "!room_id:example.com", "depth": depth, - "prev_events": prev_events, + "prev_events": prev_events or [], } if state_key is not None: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 510b63011470..4f32638f4afb 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -147,9 +147,11 @@ def inject_visibility(self, user_id, visibility): return event @defer.inlineCallbacks - def inject_room_member(self, user_id, membership="join", extra_content={}): + def inject_room_member( + self, user_id, membership="join", extra_content: dict = None + ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 4d1aee91d537..de18de81e675 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -89,9 +89,9 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings={}): +def build_rc_config(settings: dict = None): config_dict = default_config("test") - config_dict.update(settings) + config_dict.update(settings or {}) config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") return config.rc_federation From 44053d6327788d22d3dd73c4d4abc7c6666994f0 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 24 Mar 2021 20:55:38 +0100 Subject: [PATCH 2/9] news --- changelog.d/9682.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/9682.misc diff --git a/changelog.d/9682.misc b/changelog.d/9682.misc new file mode 100644 index 000000000000..428a466facfb --- /dev/null +++ b/changelog.d/9682.misc @@ -0,0 +1 @@ +Introduce flake8-bugbear to the test suite and fix some of its lint violations. From 3cfb8bbb53887c7aa0c64585a4afc71f7743358c Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 25 Mar 2021 19:09:25 +0100 Subject: [PATCH 3/9] small adjustments, fix wrong replacement --- synapse/handlers/message.py | 4 +++- synapse/storage/database.py | 1 + synapse/storage/databases/state/bg_updates.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e7946f3061a4..1209a43c5567 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -164,6 +164,8 @@ async def get_state_events( AuthError (403) if the user doesn't have permission to view members of this room. """ + state_filter = state_filter or StateFilter.all() + if at_token: # FIXME this claims to get the state at a stream position, but # get_recent_events_for_room operates by topo ordering. This therefore @@ -186,7 +188,7 @@ async def get_state_events( event = last_events[0] if visible_events: room_state = await self.state_store.get_state_for_events( - [event.event_id], state_filter=state_filter or StateFilter.all() + [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] else: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index cf42709d4482..b9ce0d3eaed4 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1538,6 +1538,7 @@ async def simple_select_many_batch( batch_size: the number of rows for each select query """ keyvalues = keyvalues or {} + results = [] # type: List[Dict[str, Any]] if not iterable: diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index df86b042c7b2..6d763f91783e 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -76,6 +76,7 @@ def _get_state_groups_from_groups_txn( self, txn, groups, state_filter: StateFilter = None ): state_filter = state_filter or StateFilter.all() + results = {group: {} for group in groups} where_clause, where_args = state_filter.make_sql_filter_clause() From d1b66682c47dbbddef68ef78c34b7554fe15e0a3 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Apr 2021 21:20:37 +0200 Subject: [PATCH 4/9] here come the Optional --- synapse/appservice/scheduler.py | 4 ++-- synapse/config/ratelimiting.py | 4 ++-- synapse/federation/units.py | 3 ++- synapse/handlers/account_validity.py | 4 ++-- synapse/handlers/appservice.py | 2 +- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/federation.py | 4 ++-- synapse/handlers/message.py | 6 +++--- synapse/handlers/register.py | 2 +- synapse/handlers/sync.py | 4 ++-- synapse/http/client.py | 4 ++-- synapse/logging/opentracing.py | 2 +- synapse/module_api/__init__.py | 4 ++-- synapse/notifier.py | 12 ++++++------ synapse/storage/database.py | 10 +++++----- synapse/storage/databases/main/events.py | 4 ++-- synapse/storage/state.py | 10 +++++----- synapse/util/caches/deferred_cache.py | 2 +- synapse/util/caches/lrucache.py | 4 ++-- tests/replication/_base.py | 2 +- tests/replication/slave/storage/test_events.py | 10 +++++----- tests/rest/client/v2_alpha/test_auth.py | 4 ++-- 22 files changed, 52 insertions(+), 51 deletions(-) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index f49ea611b417..5203ffe90fdd 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -49,7 +49,7 @@ components. """ import logging -from typing import List +from typing import List, Optional from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.events import EventBase @@ -191,7 +191,7 @@ async def send( self, service: ApplicationService, events: List[EventBase], - ephemeral: List[JsonDict] = None, + ephemeral: Optional[List[JsonDict]] = None, ): try: txn = await self.store.create_appservice_txn( diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 80adf5c42080..7a8d5851c40b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Dict, Optional from ._base import Config @@ -21,7 +21,7 @@ class RateLimitConfig: def __init__( self, config: Dict[str, float], - defaults: Dict[str, float] = None, + defaults: Optional[Dict[str, float]] = None, ): defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 2e96494c678c..0f8bf000ac3d 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -18,6 +18,7 @@ """ import logging +from typing import Optional import attr @@ -98,7 +99,7 @@ class Transaction(JsonEncodedObject): "pdus", ] - def __init__(self, transaction_id=None, pdus: list = None, **kwargs): + def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): """If we include a list of pdus then we decode then as PDU's automatically. """ diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d781bb251de8..3589c3360669 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,7 +18,7 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -241,7 +241,7 @@ async def renew_account(self, renewal_token: str) -> bool: return True async def renew_account_for_user( - self, user_id: str, expiration_ts: int = None, email_sent: bool = False + self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 95413956b661..9fb7ee335d38 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -182,7 +182,7 @@ def notify_interested_services_ephemeral( self, stream_key: str, new_token: Optional[int], - users: Collection[Union[str, UserID]] = None, + users: Optional[Collection[Union[str, UserID]]] = None, ): """This is called by the notifier in the background when a ephemeral event handled by the homeserver. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2ad9b6d930e7..739653a3fa20 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1008,7 +1008,7 @@ async def _process_other_signatures( return signature_list, failures async def _get_e2e_cross_signing_verify_key( - self, user_id: str, key_type: str, from_user_id: str = None + self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 250c5250edb5..4500eeec9959 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1772,7 +1772,7 @@ async def _make_and_verify_event( room_id: str, user_id: str, membership: str, - content: JsonDict = None, + content: JsonDict, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( @@ -1780,7 +1780,7 @@ async def _make_and_verify_event( event, room_version, ) = await self.federation_client.make_membership_event( - target_hosts, room_id, user_id, membership, content or {}, params=params + target_hosts, room_id, user_id, membership, content, params=params ) logger.debug("Got response to make_%s: %s", membership, event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1209a43c5567..a998ee8b89cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -137,7 +137,7 @@ async def get_state_events( self, user_id: str, room_id: str, - state_filter: StateFilter = None, + state_filter: Optional[StateFilter] = None, at_token: Optional[StreamToken] = None, is_guest: bool = False, ) -> List[dict]: @@ -876,7 +876,7 @@ async def handle_new_client_event( event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = None, + extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, ) -> EventBase: """Processes a new event. @@ -1074,7 +1074,7 @@ async def persist_and_notify_client_event( event: EventBase, context: EventContext, ratelimit: bool = True, - extra_users: List[UserID] = None, + extra_users: Optional[List[UserID]] = None, ) -> EventBase: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index d1281d9b28f8..03ea9c424982 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -169,7 +169,7 @@ async def register_user( user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: Iterable[str] = None, + bind_emails: Optional[Iterable[str]] = None, by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bf1c4aba12e2..fe2eb2e1e9f0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -540,7 +540,7 @@ async def _load_filtered_recents( ) async def get_state_after_event( - self, event: EventBase, state_filter: StateFilter = None + self, event: EventBase, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the room state after the given event @@ -561,7 +561,7 @@ async def get_state_at( self, room_id: str, stream_position: StreamToken, - state_filter: StateFilter = None, + state_filter: Optional[StateFilter] = None, ) -> StateMap[str]: """Get the room state at a particular stream position diff --git a/synapse/http/client.py b/synapse/http/client.py index 5ce21ad458ce..f7a07f0466c5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -297,7 +297,7 @@ class SimpleHttpClient: def __init__( self, hs: "HomeServer", - treq_args: Dict[str, Any] = None, + treq_args: Optional[Dict[str, Any]] = None, ip_whitelist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None, use_proxy: bool = False, @@ -590,7 +590,7 @@ async def put_json( uri: str, json_body: Any, args: Optional[QueryParams] = None, - headers: RawHeaders = None, + headers: Optional[RawHeaders] = None, ) -> Any: """Puts some json to the given URI. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index f6a8c420b554..947c89650147 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -478,7 +478,7 @@ def start_active_span_from_request( def start_active_span_from_edu( edu_content, operation_name, - references: list = None, + references: Optional[list] = None, tags=None, start_time=None, ignore_active_span=False, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 2cc062cdf61d..5638b8ee2dbf 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -118,7 +118,7 @@ def check_user_exists(self, user_id): return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register(self, localpart, displayname=None, emails: List[str] = None): + def register(self, localpart, displayname=None, emails: Optional[List[str]] = None): """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -142,7 +142,7 @@ def register(self, localpart, displayname=None, emails: List[str] = None): _, access_token = yield self.register_device(user_id) return user_id, access_token - def register_user(self, localpart, displayname=None, emails: List[str] = None): + def register_user(self, localpart, displayname=None, emails: Optional[List[str]] = None): """Registers a new user with given localpart and optional displayname, emails. Args: diff --git a/synapse/notifier.py b/synapse/notifier.py index d7b9a8043c83..388a5692a3d1 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -266,7 +266,7 @@ def on_new_room_event( event: EventBase, event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = None, + extra_users: Optional[Collection[UserID]] = None, ): """Unwraps event and calls `on_new_room_event_args`.""" self.on_new_room_event_args( @@ -287,7 +287,7 @@ def on_new_room_event_args( membership: Optional[str], event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, - extra_users: Collection[UserID] = None, + extra_users: Optional[Collection[UserID]] = None, ): """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -372,7 +372,7 @@ def _notify_app_services_ephemeral( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = None, + users: Optional[Collection[Union[str, UserID]]] = None, ): try: stream_token = None @@ -394,8 +394,8 @@ def on_new_event( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[Union[str, UserID]] = None, - rooms: Collection[str] = None, + users: Optional[Collection[Union[str, UserID]]] = None, + rooms: Optional[Collection[str]] = None, ): """Used to inform listeners that something has happened event wise. @@ -510,7 +510,7 @@ async def get_events_for( pagination_config: PaginationConfig, timeout: int, is_guest: bool = False, - explicit_room_id: str = None, + explicit_room_id: Optional[str] = None, ) -> EventStreamResult: """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b9ce0d3eaed4..688ed464504f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -900,7 +900,7 @@ async def simple_upsert( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = None, + insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", lock: bool = True, ) -> Optional[bool]: @@ -966,7 +966,7 @@ def simple_upsert_txn( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = None, + insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, ) -> Optional[bool]: """ @@ -1007,7 +1007,7 @@ def simple_upsert_txn_emulated( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = None, + insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, ) -> bool: """ @@ -1083,7 +1083,7 @@ def simple_upsert_txn_native_upsert( table: str, keyvalues: Dict[str, Any], values: Dict[str, Any], - insertion_values: Dict[str, Any] = None, + insertion_values: Optional[Dict[str, Any]] = None, ) -> None: """ Use the native UPSERT functionality in recent PostgreSQL versions. @@ -1519,7 +1519,7 @@ async def simple_select_many_batch( column: str, iterable: Iterable[Any], retcols: Iterable[str], - keyvalues: Dict[str, Any] = None, + keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, ) -> List[Any]: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 053daf8b6ae2..ad17123915b4 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -320,8 +320,8 @@ def _persist_events_txn( txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - state_delta_for_room: Dict[str, DeltaState] = None, - new_forward_extremeties: Dict[str, List[str]] = None, + state_delta_for_room: Optional[Dict[str, DeltaState]] = None, + new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): """Insert some number of room events into the necessary database tables. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 2987b148471c..c1c147c62ac2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -449,7 +449,7 @@ def _get_state_groups_from_groups( return self.stores.state._get_state_groups_from_groups(groups, state_filter) async def get_state_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = None + self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -485,7 +485,7 @@ async def get_state_for_events( return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Iterable[str], state_filter: StateFilter = None + self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -513,7 +513,7 @@ async def get_state_ids_for_events( return {event: event_to_state[event] for event in event_ids} async def get_state_for_event( - self, event_id: str, state_filter: StateFilter = None + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -531,7 +531,7 @@ async def get_state_for_event( return state_map[event_id] async def get_state_ids_for_event( - self, event_id: str, state_filter: StateFilter = None + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -549,7 +549,7 @@ async def get_state_ids_for_event( return state_map[event_id] def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = None + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1adc92eb905f..0532c9bd0302 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -283,7 +283,7 @@ def eb(_fail): # we return a new Deferred which will be called before any subsequent observers. return observable.observe() - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + def prefill(self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a609d25741b3..158de7abe513 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -57,7 +57,7 @@ def enumerate_leaves(node, depth): class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks: set = None): + def __init__(self, prev_node, next_node, key, value, callbacks: Optional[set] = None): self.prev_node = prev_node self.next_node = next_node self.key = key @@ -176,7 +176,7 @@ def cache_len(): self.len = synchronized(cache_len) - def add_node(key, value, callbacks: set = None): + def add_node(key, value, callbacks: Optional[set] = None): prev_node = list_root next_node = prev_node.next_node node = _Node(prev_node, next_node, key, value, callbacks or set()) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 19b8c8f3b47c..d4223c1e3749 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -266,7 +266,7 @@ def create_test_resource(self): return resource def make_worker_hs( - self, worker_app: str, extra_config: dict = None, **kwargs + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9c2b984330a0..333374b183cd 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable +from typing import Iterable, Optional from canonicaljson import encode_canonical_json @@ -333,11 +333,11 @@ def build_event( room_id=ROOM_ID, type="m.room.message", key=None, - internal: dict = None, + internal: Optional[dict] = None, depth=None, - prev_events: list = None, - auth_events: list = None, - prev_state: list = None, + prev_events: Optional[list] = None, + auth_events: Optional[list] = None, + prev_state: Optional[list] = None, redacts=None, push_actions: Iterable = frozenset(), **content diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 9734a2159a1a..4534428d3e00 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Union, Optional from twisted.internet.defer import succeed @@ -74,7 +74,7 @@ def register(self, expected_response: int, body: JsonDict) -> FakeChannel: return channel def recaptcha( - self, session: str, expected_post_response: int, post_session: str = None + self, session: str, expected_post_response: int, post_session: Optional[str] = None ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: From 01c5c3377e37b79405b111a9f371ce8985fc1f83 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Apr 2021 21:23:49 +0200 Subject: [PATCH 5/9] grace ye linter gods --- synapse/handlers/account_validity.py | 5 ++++- synapse/module_api/__init__.py | 4 +++- synapse/util/caches/deferred_cache.py | 4 +++- synapse/util/caches/lrucache.py | 4 +++- tests/rest/client/v2_alpha/test_auth.py | 7 +++++-- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 3589c3360669..bee1447c2ee5 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -241,7 +241,10 @@ async def renew_account(self, renewal_token: str) -> bool: return True async def renew_account_for_user( - self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False + self, + user_id: str, + expiration_ts: Optional[int] = None, + email_sent: bool = False, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 5638b8ee2dbf..24b33ffae76a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -142,7 +142,9 @@ def register(self, localpart, displayname=None, emails: Optional[List[str]] = No _, access_token = yield self.register_device(user_id) return user_id, access_token - def register_user(self, localpart, displayname=None, emails: Optional[List[str]] = None): + def register_user( + self, localpart, displayname=None, emails: Optional[List[str]] = None + ): """Registers a new user with given localpart and optional displayname, emails. Args: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 0532c9bd0302..dd392cf69437 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -283,7 +283,9 @@ def eb(_fail): # we return a new Deferred which will be called before any subsequent observers. return observable.observe() - def prefill(self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None): + def prefill( + self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None + ): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 158de7abe513..a77fbea0f79f 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -57,7 +57,9 @@ def enumerate_leaves(node, depth): class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks: Optional[set] = None): + def __init__( + self, prev_node, next_node, key, value, callbacks: Optional[set] = None + ): self.prev_node = prev_node self.next_node = next_node self.key = key diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 4534428d3e00..ed433d93334c 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Optional +from typing import Optional, Union from twisted.internet.defer import succeed @@ -74,7 +74,10 @@ def register(self, expected_response: int, body: JsonDict) -> FakeChannel: return channel def recaptcha( - self, session: str, expected_post_response: int, post_session: Optional[str] = None + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: From 62477776b4e16f89d19c8615dc18f19009b1cab7 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Apr 2021 21:55:16 +0200 Subject: [PATCH 6/9] buncha fixes and additions --- contrib/cmdclient/console.py | 3 ++- contrib/cmdclient/http.py | 9 ++++++--- synapse/events/__init__.py | 6 +++--- synapse/handlers/account_validity.py | 4 ++-- synapse/handlers/e2e_keys.py | 2 +- synapse/http/proxyagent.py | 2 +- synapse/notifier.py | 2 +- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/state/bg_updates.py | 3 ++- synapse/storage/databases/state/store.py | 4 ++-- synapse/util/caches/deferred_cache.py | 4 +--- tests/http/federation/test_matrix_federation_agent.py | 11 +++++++++-- tests/rest/client/v1/utils.py | 4 ++-- tests/rest/client/v2_alpha/test_auth.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 3 ++- tests/storage/test_id_generators.py | 8 ++++---- tests/storage/test_redaction.py | 8 +++++++- tests/test_state.py | 4 ++-- tests/util/test_ratelimitutils.py | 4 +++- 20 files changed, 53 insertions(+), 34 deletions(-) diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 243b658425f1..856dd437db91 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -24,6 +24,7 @@ import time import urllib from http import TwistedHttpClient +from typing import Optional import nacl.encoding import nacl.signing @@ -718,7 +719,7 @@ def _run_and_pprint( method, path, data=None, - query_params: dict = None, + query_params: Optional[dict] = None, alt_text=None, ): """Runs an HTTP request and pretty prints the output. diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index a381af7cc92f..e20dcaea363a 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -16,6 +16,7 @@ import json import urllib from pprint import pformat +from typing import Optional from twisted.internet import defer, reactor from twisted.web.client import Agent, readBody @@ -85,7 +86,7 @@ def get_json(self, url, args=None): body = yield readBody(response) defer.returnValue(json.loads(body)) - def _create_put_request(self, url, json_data, headers_dict: dict = None): + def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None): """Wrapper of _create_request to issue a PUT request""" headers_dict = headers_dict or {} @@ -96,7 +97,7 @@ def _create_put_request(self, url, json_data, headers_dict: dict = None): "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) - def _create_get_request(self, url, headers_dict: dict = None): + def _create_get_request(self, url, headers_dict: Optional[dict] = None): """Wrapper of _create_request to issue a GET request""" return self._create_request("GET", url, headers_dict=headers_dict or {}) @@ -126,7 +127,9 @@ def do_request( defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def _create_request(self, method, url, producer=None, headers_dict: dict = None): + def _create_request( + self, method, url, producer=None, headers_dict: Optional[dict] = None + ): """Creates and sends a request to the given url""" headers_dict = headers_dict or {} diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 6abcab8374e2..f9032e36977f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -330,7 +330,7 @@ def __init__( self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = None, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ): internal_metadata_dict = internal_metadata_dict or {} @@ -388,7 +388,7 @@ def __init__( self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: JsonDict = None, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ): internal_metadata_dict = internal_metadata_dict or {} @@ -511,7 +511,7 @@ def _event_type_from_format_version(format_version: int) -> Type[EventBase]: def make_event_from_dict( event_dict: JsonDict, room_version: RoomVersion = RoomVersions.V1, - internal_metadata_dict: JsonDict = None, + internal_metadata_dict: Optional[JsonDict] = None, rejected_reason: Optional[str] = None, ) -> EventBase: """Construct an EventBase from the given event dict""" diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index bee1447c2ee5..d3881b6845ee 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,7 +18,7 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -243,7 +243,7 @@ async def renew_account(self, renewal_token: str) -> bool: async def renew_account_for_user( self, user_id: str, - expiration_ts: Optional[int] = None, + expiration_ts: int = None, email_sent: bool = False, ) -> int: """Renews the account attached to a given user by pushing back the diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 739653a3fa20..2ad9b6d930e7 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1008,7 +1008,7 @@ async def _process_other_signatures( return signature_list, failures async def _get_e2e_cross_signing_verify_key( - self, user_id: str, key_type: str, from_user_id: Optional[str] = None + self, user_id: str, key_type: str, from_user_id: str = None ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 3d2aee933641..ea5ad14cb07c 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -88,7 +88,7 @@ def __init__( self, reactor, proxy_reactor=None, - contextFactory: IPolicyForHTTPS = None, + contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout=None, bindAddress=None, pool=None, diff --git a/synapse/notifier.py b/synapse/notifier.py index 7ce34380af3c..ee37b8757e6d 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -551,7 +551,7 @@ async def get_events_for( pagination_config: PaginationConfig, timeout: int, is_guest: bool = False, - explicit_room_id: Optional[str] = None, + explicit_room_id: str = None, ) -> EventStreamResult: """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 47f02745abe5..215def208385 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1171,7 +1171,7 @@ async def register_user_group_membership( user_id: str, membership: str, is_admin: bool = False, - content: JsonDict = None, + content: Optional[JsonDict] = None, local_attestation: Optional[dict] = None, remote_attestation: Optional[dict] = None, is_publicised: bool = False, diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 6b6dd4135faa..93431efe0023 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -190,7 +190,7 @@ def _get_current_state_ids_txn(txn): # FIXME: how should this be cached? async def get_filtered_current_state_ids( - self, room_id: str, state_filter: StateFilter = None + self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 6d763f91783e..75c09b3687fd 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -73,7 +74,7 @@ def _count_state_group_hops_txn(self, txn, state_group): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter: StateFilter = None + self, txn, groups, state_filter: Optional[StateFilter] = None ): state_filter = state_filter or StateFilter.all() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 137afb99112d..dfcf89d91c9b 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore @@ -210,7 +210,7 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter): return state_filter.filter_state(state_dict_ids), not missing_types async def _get_state_for_groups( - self, groups: Iterable[int], state_filter: StateFilter = None + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index dd392cf69437..1adc92eb905f 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -283,9 +283,7 @@ def eb(_fail): # we return a new Deferred which will be called before any subsequent observers. return observable.observe() - def prefill( - self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None - ): + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 70d4e7e6af48..73e12ea6c3ca 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from mock import Mock @@ -180,7 +181,11 @@ def _make_get_request(self, uri): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers: dict = None + self, + client_factory, + expected_sni, + content, + response_headers: Optional[dict] = None, ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -205,7 +210,9 @@ def _handle_well_known_connection( self._send_well_known_response(request, content, headers=response_headers or {}) return well_known_server - def _send_well_known_response(self, request, content, headers: dict = None): + def _send_well_known_response( + self, request, content, headers: Optional[dict] = None + ): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 75d1512234ff..8a4dddae2b2a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -132,7 +132,7 @@ def change_membership( src: str, targ: str, membership: str, - extra_data: dict = None, + extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, ) -> None: @@ -190,7 +190,7 @@ def send_event( self, room_id, type, - content: dict = None, + content: Optional[dict] = None, txn_id=None, tok=None, expect_code=200, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ed433d93334c..5cbcc4ab9efe 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -77,7 +77,7 @@ def recaptcha( self, session: str, expected_post_response: int, - post_session: Optional[str] = None, + post_session: str = None, ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index f78500d0f03d..21ee436b919b 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -16,6 +16,7 @@ import itertools import json import urllib +from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -681,7 +682,7 @@ def _send_relation( relation_type, event_type, key=None, - content: dict = None, + content: Optional[dict] = None, access_token=None, parent_id=None, ): diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 12bbb867354f..6c389fe9acd8 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from synapse.storage.database import DatabasePool from synapse.storage.engines import IncorrectDatabaseSetup @@ -45,7 +45,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers: List[str] = None + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -478,7 +478,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers: List[str] = None + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -614,7 +614,7 @@ def _setup_db(self, txn): ) def _create_id_generator( - self, instance_name="master", writers: List[str] = None + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 390c3e26ec34..92017beed95e 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from canonicaljson import json @@ -50,7 +51,12 @@ def prepare(self, reactor, clock, hs): self.depth = 1 def inject_room_member( - self, room, user, membership, replaces_state=None, extra_content: dict = None + self, + room, + user, + membership, + replaces_state=None, + extra_content: Optional[dict] = None, ): content = {"membership": membership} content.update(extra_content or {}) diff --git a/tests/test_state.py b/tests/test_state.py index 9436dd0f8ffd..1d2019699df6 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from mock import Mock @@ -38,7 +38,7 @@ def create_event( state_key=None, depth=2, event_id=None, - prev_events: List[str] = None, + prev_events: Optional[List[str]] = None, **kwargs ): global _next_event_id diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index de18de81e675..3fed55090a7c 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.config.homeserver import HomeServerConfig from synapse.util.ratelimitutils import FederationRateLimiter @@ -89,7 +91,7 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings: dict = None): +def build_rc_config(settings: Optional[dict] = None): config_dict = default_config("test") config_dict.update(settings or {}) config = HomeServerConfig() From 27f7a2e202ac8190cbd05eb4f24dbae30e2bf2f9 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Apr 2021 22:08:20 +0200 Subject: [PATCH 7/9] revert some, add some --- synapse/handlers/account_validity.py | 5 +---- tests/rest/client/v2_alpha/test_auth.py | 7 ++----- tests/test_visibility.py | 3 ++- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d3881b6845ee..d781bb251de8 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -241,10 +241,7 @@ async def renew_account(self, renewal_token: str) -> bool: return True async def renew_account_for_user( - self, - user_id: str, - expiration_ts: int = None, - email_sent: bool = False, + self, user_id: str, expiration_ts: int = None, email_sent: bool = False ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 5cbcc4ab9efe..9734a2159a1a 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Union from twisted.internet.defer import succeed @@ -74,10 +74,7 @@ def register(self, expected_response: int, body: JsonDict) -> FakeChannel: return channel def recaptcha( - self, - session: str, - expected_post_response: int, - post_session: str = None, + self, session: str, expected_post_response: int, post_session: str = None ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 4f32638f4afb..1b4dd47a8238 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from mock import Mock @@ -148,7 +149,7 @@ def inject_visibility(self, user_id, visibility): @defer.inlineCallbacks def inject_room_member( - self, user_id, membership="join", extra_content: dict = None + self, user_id, membership="join", extra_content: Optional[dict] = None ): content = {"membership": membership} content.update(extra_content or {}) From 0dee2984849a8797366c038c49d270bec61340e4 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Apr 2021 21:14:43 +0200 Subject: [PATCH 8/9] Apply suggestions from code review Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- contrib/cmdclient/http.py | 2 +- synapse/storage/util/id_generators.py | 2 +- synapse/util/caches/lrucache.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index e20dcaea363a..820ad56b6395 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -103,7 +103,7 @@ def _create_get_request(self, url, headers_dict: Optional[dict] = None): @defer.inlineCallbacks def do_request( - self, method, url, data=None, qparams=None, jsonreq=True, headers: dict = None + self, method, url, data=None, qparams=None, jsonreq=True, headers: Optional[dict] = None ): headers = headers or {} diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0d8653dda806..32d6cc16b9a6 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -96,7 +96,7 @@ def __init__( db_conn, table, column, - extra_tables: Iterable[Tuple] = frozenset(), + extra_tables: Iterable[Tuple[str, str]] = (), step=1, ): assert step != 0 diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a77fbea0f79f..8131646466f6 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -239,7 +239,7 @@ def cache_get( def cache_get( key: KT, default: Optional[T] = None, - callbacks: Iterable[Callable[[], None]] = frozenset(), + callbacks: Iterable[Callable[[], None]] = (), update_metrics: bool = True, ): node = cache.get(key, None) @@ -256,7 +256,7 @@ def cache_get( @synchronized def cache_set( - key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = frozenset() + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () ): node = cache.get(key, None) if node is not None: From d1c4b52bd24ec4d2d42338c6eb706b87ecc76c3d Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Apr 2021 21:19:09 +0200 Subject: [PATCH 9/9] tactical linter petting --- contrib/cmdclient/http.py | 8 +++++++- synapse/util/caches/lrucache.py | 4 +--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 820ad56b6395..1cf913756e6a 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -103,7 +103,13 @@ def _create_get_request(self, url, headers_dict: Optional[dict] = None): @defer.inlineCallbacks def do_request( - self, method, url, data=None, qparams=None, jsonreq=True, headers: Optional[dict] = None + self, + method, + url, + data=None, + qparams=None, + jsonreq=True, + headers: Optional[dict] = None, ): headers = headers or {} diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 8131646466f6..20c8e2d9f5ec 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -255,9 +255,7 @@ def cache_get( return default @synchronized - def cache_set( - key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () - ): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause