From 32b6ff9b6256a58d90ea0fd170007f0d577b4bd6 Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Thu, 11 Aug 2022 10:33:51 +0100 Subject: [PATCH 1/2] feat: add support for getting a threads message --- naff/api/events/processors/thread_events.py | 5 ++++- naff/models/discord/channel.py | 5 +++++ naff/models/discord/message.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/naff/api/events/processors/thread_events.py b/naff/api/events/processors/thread_events.py index 5b4886111..a62131af1 100644 --- a/naff/api/events/processors/thread_events.py +++ b/naff/api/events/processors/thread_events.py @@ -14,7 +14,10 @@ class ThreadEvents(EventMixinTemplate): @Processor.define() async def _on_raw_thread_create(self, event: "RawGatewayEvent") -> None: - self.dispatch(events.ThreadCreate(self.cache.place_channel_data(event.data))) + thread = self.cache.place_channel_data(event.data) + if message := self.cache.get_message(event.data["parent_id"], event.data["id"]): + message.thread = thread + self.dispatch(events.ThreadCreate(thread)) @Processor.define() async def _on_raw_thread_update(self, event: "RawGatewayEvent") -> None: diff --git a/naff/models/discord/channel.py b/naff/models/discord/channel.py index 447128e34..f8089d144 100644 --- a/naff/models/discord/channel.py +++ b/naff/models/discord/channel.py @@ -1802,6 +1802,11 @@ def parent_channel(self) -> Union[GuildText, "GuildForum"]: """The channel this thread is a child of.""" return self._client.cache.get_channel(self.parent_id) + @property + def parent_message(self) -> Optional["Message"]: + """The message this thread is a child of.""" + return self._client.cache.get_message(self.parent_id, self.id) + @property def mention(self) -> str: """Returns a string that would mention this thread.""" diff --git a/naff/models/discord/message.py b/naff/models/discord/message.py index 6558b5f9b..6ff902f81 100644 --- a/naff/models/discord/message.py +++ b/naff/models/discord/message.py @@ -313,6 +313,7 @@ class Message(BaseMessage): """Sent if the message contains components like buttons, action rows, or other interactive components""" sticker_items: Optional[List["models.StickerItem"]] = field(default=None) """Sent if the message contains stickers""" + thread: Optional["models.ThreadChannel"] = field(default=None) _mention_ids: List["Snowflake_Type"] = field(factory=list) _mention_roles: List["Snowflake_Type"] = field(factory=list) _referenced_message_id: Optional["Snowflake_Type"] = field(default=None) From 96ef663e0dc426199bca47fe9024972bb8a18857 Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Thu, 11 Aug 2022 10:37:38 +0100 Subject: [PATCH 2/2] refactor: use property not attribute --- naff/api/events/processors/thread_events.py | 5 +---- naff/models/discord/message.py | 6 +++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/naff/api/events/processors/thread_events.py b/naff/api/events/processors/thread_events.py index a62131af1..5b4886111 100644 --- a/naff/api/events/processors/thread_events.py +++ b/naff/api/events/processors/thread_events.py @@ -14,10 +14,7 @@ class ThreadEvents(EventMixinTemplate): @Processor.define() async def _on_raw_thread_create(self, event: "RawGatewayEvent") -> None: - thread = self.cache.place_channel_data(event.data) - if message := self.cache.get_message(event.data["parent_id"], event.data["id"]): - message.thread = thread - self.dispatch(events.ThreadCreate(thread)) + self.dispatch(events.ThreadCreate(self.cache.place_channel_data(event.data))) @Processor.define() async def _on_raw_thread_update(self, event: "RawGatewayEvent") -> None: diff --git a/naff/models/discord/message.py b/naff/models/discord/message.py index 6ff902f81..e6a04ed7e 100644 --- a/naff/models/discord/message.py +++ b/naff/models/discord/message.py @@ -313,7 +313,6 @@ class Message(BaseMessage): """Sent if the message contains components like buttons, action rows, or other interactive components""" sticker_items: Optional[List["models.StickerItem"]] = field(default=None) """Sent if the message contains stickers""" - thread: Optional["models.ThreadChannel"] = field(default=None) _mention_ids: List["Snowflake_Type"] = field(factory=list) _mention_roles: List["Snowflake_Type"] = field(factory=list) _referenced_message_id: Optional["Snowflake_Type"] = field(default=None) @@ -330,6 +329,11 @@ async def mention_roles(self) -> AsyncGenerator["models.Role", None]: for r_id in self._mention_roles: yield await self._client.cache.fetch_role(self._guild_id, r_id) + @property + def thread(self) -> "models.TYPE_THREAD_CHANNEL": + """The thread that was started from this message, if any""" + return self._client.cache.get_channel(self.id) + async def fetch_referenced_message(self) -> Optional["Message"]: """ Fetch the message this message is referencing, if any.