diff --git a/matrix_reminder_bot/callbacks.py b/matrix_reminder_bot/callbacks.py index ba1f7fe..64e7dde 100644 --- a/matrix_reminder_bot/callbacks.py +++ b/matrix_reminder_bot/callbacks.py @@ -14,7 +14,7 @@ from matrix_reminder_bot.bot_commands import Command from matrix_reminder_bot.config import CONFIG from matrix_reminder_bot.errors import CommandError -from matrix_reminder_bot.functions import send_text_to_room +from matrix_reminder_bot.functions import is_allowed_user, send_text_to_room from matrix_reminder_bot.storage import Storage logger = logging.getLogger(__name__) @@ -71,6 +71,13 @@ async def message(self, room: MatrixRoom, event: RoomMessageText): if event.sender == self.client.user: return + # Ignore messages from disallowed users + if not is_allowed_user(event.sender): + logger.debug( + f"Ignoring event {event.event_id} in room {room.room_id} as the sender {event.sender} is not allowed." + ) + return + # Ignore broken events if not event.body: return @@ -123,6 +130,11 @@ async def invite(self, room: MatrixRoom, event: InviteMemberEvent): """Callback for when an invite is received. Join the room specified in the invite""" logger.debug(f"Got invite to {room.room_id} from {event.sender}.") + # Don't respond to invites from disallowed users + if not is_allowed_user(event.sender): + logger.debug(f"{event.sender} is not allowed, not responding to invite.") + return + # Attempt to join 3 times before giving up for attempt in range(3): result = await self.client.join(room.room_id) diff --git a/matrix_reminder_bot/config.py b/matrix_reminder_bot/config.py index 220d980..b4d1aa1 100644 --- a/matrix_reminder_bot/config.py +++ b/matrix_reminder_bot/config.py @@ -42,6 +42,12 @@ def __init__(self): self.timezone: str = "" + self.allowlist_enabled: bool = False + self.allowlist_regexes: list[re.Pattern] = [] + + self.blocklist_enabled: bool = False + self.blocklist_regexes: list[re.Pattern] = [] + def read_config(self, filepath: str): if not os.path.isfile(filepath): raise ConfigError(f"Config file '{filepath}' does not exist") @@ -122,6 +128,65 @@ def read_config(self, filepath: str): # Reminder configuration self.timezone = self._get_cfg(["reminders", "timezone"], default="Etc/UTC") + # Allowlist configuration + allowlist_enabled = self._get_cfg(["allowlist", "enabled"], required=True) + if not isinstance(allowlist_enabled, bool): + raise ConfigError("allowlist.enabled must be a boolean value") + self.allowlist_enabled = allowlist_enabled + + self.allowlist_regexes = self._compile_regexes( + ["allowlist", "regexes"], required=True + ) + + # Blocklist configuration + blocklist_enabled = self._get_cfg(["blocklist", "enabled"], required=True) + if not isinstance(blocklist_enabled, bool): + raise ConfigError("blocklist.enabled must be a boolean value") + self.blocklist_enabled = blocklist_enabled + + self.blocklist_regexes = self._compile_regexes( + ["blocklist", "regexes"], required=True + ) + + def _compile_regexes( + self, path: list[str], required: bool = True + ) -> list[re.Pattern]: + """Compile a config option containing a list of strings into re.Pattern objects. + + Args: + path: The path to the config option. + required: True, if the config option is mandatory. + + Returns: + A list of re.Pattern objects. + + Raises: + ConfigError: + - If required is specified, but the config option does not exist. + - If the config option is not a list of strings. + - If the config option contains an invalid regular expression. + """ + + readable_path = ".".join(path) + regex_strings = self._get_cfg(path, required=required) # raises ConfigError + + if not isinstance(regex_strings, list) or ( + isinstance(regex_strings, list) + and any(not isinstance(x, str) for x in regex_strings) + ): + raise ConfigError(f"{readable_path} must be a list of strings") + + compiled_regexes = [] + for regex in regex_strings: + try: + compiled_regexes.append(re.compile(regex)) + except re.error as e: + raise ConfigError( + f"'{e.pattern}' contained in {readable_path} is not a valid regular expression" + ) + + return compiled_regexes + def _get_cfg( self, path: List[str], diff --git a/matrix_reminder_bot/functions.py b/matrix_reminder_bot/functions.py index f2dbbe7..db6affc 100644 --- a/matrix_reminder_bot/functions.py +++ b/matrix_reminder_bot/functions.py @@ -115,3 +115,29 @@ def make_pill(user_id: str, displayname: str = None) -> str: displayname = user_id return f'{displayname}' + + +def is_allowed_user(user_id: str) -> bool: + """Returns if the bot is allowed to interact with the given user + + Args: + user_id: The MXID of the user. + + Returns: + True, if the bot is allowed to interact with the given user. + """ + allowed = not CONFIG.allowlist_enabled + + if CONFIG.allowlist_enabled: + for regex in CONFIG.allowlist_regexes: + if regex.fullmatch(user_id): + allowed = True + break + + if CONFIG.blocklist_enabled: + for regex in CONFIG.blocklist_regexes: + if regex.fullmatch(user_id): + allowed = False + break + + return allowed diff --git a/sample.config.yaml b/sample.config.yaml index 7062f54..6e9438f 100644 --- a/sample.config.yaml +++ b/sample.config.yaml @@ -36,6 +36,33 @@ reminders: # If not set, UTC will be used #timezone: "Europe/London" +# Restrict the bot to only respond to certain MXIDs +allowlist: + # Set to true to enable the allowlist + enabled: false + # A list of MXID regexes to be allowed + # To allow a certain homeserver: + # regexes: ["@[a-z0-9-_.]+:myhomeserver.tld"] + # To allow a set of users: + # regexes: ["@alice:someserver.tld", "@bob:anotherserver.tld"] + # To allow nobody (same as blocking every MXID): + # regexes: [] + regexes: [] + +# Prevent the bot from responding to certain MXIDs +# If both allowlist and blocklist are enabled, blocklist entries takes precedence +blocklist: + # Set to true to enable the blocklist + enabled: false + # A list of MXID regexes to be blocked + # To block a certain homeserver: + # regexes: [".*:myhomeserver.tld"] + # To block a set of users: + # regexes: ["@alice:someserver.tld", "@bob:anotherserver.tld"] + # To block absolutely everyone (same as allowing nobody): + # regexes: [".*"] + regexes: [] + # Logging setup logging: # Logging level