From 6382598b699822ad8cd20b0c9cc3ce6d7e1b4478 Mon Sep 17 00:00:00 2001 From: GnomedDev Date: Sat, 12 Apr 2025 15:38:12 +0100 Subject: [PATCH] Rework Id -> Channel methods --- examples/testing/src/main.rs | 4 +- src/model/channel/channel_id.rs | 76 ++++++++++++++++++++++----------- src/model/channel/message.rs | 25 ++++++++--- src/model/channel/thread.rs | 53 +++++++++++++++++++++++ 4 files changed, 125 insertions(+), 33 deletions(-) diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index aff0340c31c..bca51e25137 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -167,7 +167,7 @@ async fn message(ctx: &Context, msg: &Message) -> Result<(), serenity::Error> { .await?; } else if msg.content == "assigntags" { let forum_id = msg.guild_channel(&ctx).await?.parent_id.unwrap(); - let forum = forum_id.widen().to_guild_channel(&ctx, msg.guild_id).await?; + let forum = forum_id.to_guild_channel(&ctx, msg.guild_id).await?; channel_id .expect_thread() .edit( @@ -202,7 +202,7 @@ async fn message(ctx: &Context, msg: &Message) -> Result<(), serenity::Error> { msg.author.id.dm(&ctx.http, builder).await?; } else if let Some(channel) = msg.content.strip_prefix("movetorootandback") { let mut channel = { - let channel_id = channel.trim().parse::().unwrap(); + let channel_id = channel.trim().parse::().unwrap(); channel_id.to_guild_channel(&ctx, msg.guild_id).await.unwrap() }; diff --git a/src/model/channel/channel_id.rs b/src/model/channel/channel_id.rs index 236975eba76..cbad3545661 100644 --- a/src/model/channel/channel_id.rs +++ b/src/model/channel/channel_id.rs @@ -22,6 +22,8 @@ use crate::builder::{ }; #[cfg(all(feature = "cache", feature = "model"))] use crate::cache::Cache; +#[cfg(all(feature = "cache", feature = "temp_cache", feature = "model"))] +use crate::cache::MaybeOwnedArc; #[cfg(feature = "model")] use crate::http::{CacheHttp, Http, Typing}; use crate::model::prelude::*; @@ -39,6 +41,55 @@ impl ChannelId { #[cfg(feature = "model")] impl ChannelId { + /// Fetches a channel from the cache, falling back to HTTP/temp cache. + /// + /// It is highly recommended to pass the `guild_id` parameter as otherwise this may perform many + /// HTTP requests. + /// + /// # Errors + /// + /// Errors if the HTTP fallback fails, or if the channel does not come from the guild passed. + pub async fn to_guild_channel( + self, + cache_http: impl CacheHttp, + guild_id: Option, + ) -> Result { + #[cfg(feature = "cache")] + if let Some(cache) = cache_http.cache() { + if let Some(guild_id) = guild_id { + if let Some(guild) = cache.guild(guild_id) { + if let Some(channel) = guild.channels.get(&self) { + return Ok(channel.clone()); + } + } + } + + #[cfg(feature = "temp_cache")] + if let Some(temp_channel) = cache.temp_channels.get(&self) { + if guild_id.is_some_and(|id| temp_channel.base.guild_id != id) { + return Err(Error::Model(ModelError::ChannelNotFound)); + } + + return Ok(GuildChannel::clone(&temp_channel)); + } + } + + let channel = cache_http.http().get_channel(self.widen()).await?; + let guild_channel = channel.guild().ok_or(ModelError::InvalidChannelType)?; + + #[cfg(all(feature = "cache", feature = "temp_cache"))] + if let Some(cache) = cache_http.cache() { + let cached_channel = MaybeOwnedArc::new(guild_channel.clone()); + cache.temp_channels.insert(self, cached_channel); + } + + if guild_id.is_some_and(|id| guild_channel.base.guild_id != id) { + return Err(Error::Model(ModelError::ChannelNotFound)); + } + + Ok(guild_channel) + } + /// Creates an invite for the given channel. /// /// **Note**: Requires the [Create Instant Invite] permission. @@ -595,8 +646,6 @@ impl GenericChannelId { #[cfg(all(feature = "cache", feature = "temp_cache"))] if let Some(cache) = cache_http.cache() { - use crate::cache::MaybeOwnedArc; - match &channel { Channel::Guild(guild_channel) => { let cached_channel = MaybeOwnedArc::new(guild_channel.clone()); @@ -614,29 +663,6 @@ impl GenericChannelId { Ok(channel) } - /// Fetches a channel from the cache, falling back to HTTP/temp cache. - /// - /// It is highly recommended to pass the `guild_id` parameter as otherwise this may perform many - /// HTTP requests. - /// - /// # Errors - /// - /// Errors if the HTTP fallback fails, or if the channel does not come from the guild passed. - pub async fn to_guild_channel( - self, - cache_http: impl CacheHttp, - guild_id: Option, - ) -> Result { - let channel = self.to_channel(cache_http, guild_id).await?; - let guild_channel = channel.guild().ok_or(ModelError::InvalidChannelType)?; - - if guild_id.is_some_and(|id| guild_channel.base.guild_id != id) { - return Err(Error::Model(ModelError::ChannelNotFound)); - } - - Ok(guild_channel) - } - /// Gets a message from the channel. /// /// If the cache feature is enabled the cache will be checked first. If not found it will diff --git a/src/model/channel/message.rs b/src/model/channel/message.rs index 6872066ac56..a1846d18f1a 100644 --- a/src/model/channel/message.rs +++ b/src/model/channel/message.rs @@ -169,8 +169,9 @@ impl Message { self.channel_id.expect_channel().crosspost(http, self.id).await } - /// First attempts to find a [`Channel`] by its Id in the cache, upon failure requests it via - /// HTTP. + /// Retrieves the [`Channel`] the message was sent in. + /// + /// See [`GenericChannelId::to_channel`] for information about how this is retrieved. /// /// # Errors /// @@ -179,14 +180,26 @@ impl Message { self.channel_id.to_channel(cache_http, self.guild_id).await } - /// First attempts to find the [`GuildChannel`] by it's Id in the cache, upon failure requests - /// it via HTTP. + /// Retrieves the [`GuildChannel`] the message was sent in. + /// + /// See [`ChannelId::to_guild_channel`] for information on how this is retrieved. /// /// # Errors /// - /// Can return an error if the HTTP request fails, or this is executed in a DM channel. + /// Can return an error if the HTTP request fails, or this is not called in a guild channel. pub async fn guild_channel(&self, cache_http: impl CacheHttp) -> Result { - self.channel_id.to_guild_channel(cache_http, self.guild_id).await + self.channel_id.expect_channel().to_guild_channel(cache_http, self.guild_id).await + } + + /// Retrieves the [`GuildThread`] the message was sent in. + /// + /// See [`ThreadId::to_thread`] for information on how this is retrieved. + /// + /// # Errors + /// + /// Can return an error if the HTTP request fails, or this is not called in a guild thread. + pub async fn guild_thread(&self, cache_http: impl CacheHttp) -> Result { + self.channel_id.expect_thread().to_thread(cache_http, self.guild_id).await } /// Calculates the permissions of the message author in the current channel. diff --git a/src/model/channel/thread.rs b/src/model/channel/thread.rs index 0542c6ef8a2..6c92a609a03 100644 --- a/src/model/channel/thread.rs +++ b/src/model/channel/thread.rs @@ -1,6 +1,8 @@ use super::*; #[cfg(feature = "model")] use crate::builder::{CreateMessage, EditThread}; +#[cfg(feature = "model")] +use crate::http::CacheHttp; use crate::internal::prelude::*; use crate::model::utils::is_false; @@ -42,6 +44,57 @@ impl ThreadId { #[cfg(feature = "model")] impl ThreadId { + /// Fetches a thread from the cache, falling back to HTTP/temp cache. + /// + /// It is highly recommended to pass the `guild_id` parameter as otherwise this may perform many + /// HTTP requests. + /// + /// # Errors + /// + /// Errors if the HTTP fallback fails, or if the channel does not come from the guild passed. + pub async fn to_thread( + self, + cache_http: impl CacheHttp, + guild_id: Option, + ) -> Result { + #[cfg(feature = "cache")] + if let Some(cache) = cache_http.cache() { + if let Some(guild_id) = guild_id { + if let Some(guild) = cache.guild(guild_id) { + if let Some(thread) = guild.threads.get(&self) { + return Ok(thread.clone()); + } + } + } + + #[cfg(feature = "temp_cache")] + if let Some(temp_thread) = cache.temp_threads.get(&self) { + if guild_id.is_some_and(|id| temp_thread.base.guild_id != id) { + return Err(Error::Model(ModelError::ChannelNotFound)); + } + + return Ok(GuildThread::clone(&temp_thread)); + } + } + + let channel = cache_http.http().get_channel(self.widen()).await?; + let guild_thread = channel.thread().ok_or(ModelError::InvalidChannelType)?; + + #[cfg(all(feature = "cache", feature = "temp_cache"))] + if let Some(cache) = cache_http.cache() { + use crate::cache::wrappers::MaybeOwnedArc; + + let cached_thread = MaybeOwnedArc::new(guild_thread.clone()); + cache.temp_threads.insert(self, cached_thread); + } + + if guild_id.is_some_and(|id| guild_thread.base.guild_id != id) { + return Err(Error::Model(ModelError::ChannelNotFound)); + } + + Ok(guild_thread) + } + /// Gets the thread members, if this channel is a thread. /// /// # Errors