Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions src/gateway/sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub use self::shard_runner::{ShardRunner, ShardRunnerMessage, ShardRunnerOptions
use super::{ActivityData, ChunkGuildFilter, GatewayError, PresenceData, WsClient};
use crate::constants::{self, CloseCode};
use crate::internal::prelude::*;
use crate::model::event::{Event, GatewayEvent};
use crate::model::event::{DeserializedEvent, Event, GatewayEvent, UnknownEvent};
use crate::model::gateway::{GatewayIntents, ShardInfo};
use crate::model::id::{ApplicationId, GuildId, ShardId};
use crate::model::user::OnlineStatus;
Expand Down Expand Up @@ -312,13 +312,24 @@ impl Shard {
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
fn handle_gateway_dispatch(&mut self, seq: u64, event: &[u8]) -> Result<Event> {
fn handle_gateway_dispatch(&mut self, seq: u64, event: DeserializedEvent) -> Option<Event> {
if seq > self.seq + 1 {
warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq);
}

self.seq = seq;
let event = deserialize_and_log_event(event)?;

let event = match event {
DeserializedEvent::Success(event) => event,
DeserializedEvent::Unknown(UnknownEvent {
ty,
ref data,
}) => {
debug!("Unknown event: {ty}");
debug!("Failing event data: {data:?}");
return None;
},
};

match &event {
Event::Ready(ready) => {
Expand All @@ -345,7 +356,7 @@ impl Shard {
_ => {},
}

Ok(event)
Some(event)
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
Expand Down Expand Up @@ -442,9 +453,9 @@ impl Shard {
Ok(GatewayEvent::Dispatch {
seq,
event,
}) => self
.handle_gateway_dispatch(seq, &event)
.map(|e| Some(ShardAction::Dispatch(Box::new(e)))),
}) => Ok(self
.handle_gateway_dispatch(seq, *event)
.map(|e| ShardAction::Dispatch(Box::new(e)))),
Ok(GatewayEvent::Heartbeat) => {
info!("[{:?}] Received shard heartbeat", self.shard_info);

Expand Down Expand Up @@ -736,23 +747,6 @@ async fn connect(base_url: &str, compression: TransportCompression) -> Result<Ws
WsClient::connect(url, compression).await
}

fn deserialize_and_log_event(event: &[u8]) -> Result<Event> {
serde_json::from_slice(event).map_err(|err| {
let err_dbg = format!("{err:?}");
if let Some((variant_name, _)) =
err_dbg.strip_prefix(r#"Error("unknown variant `"#).and_then(|s| s.split_once('`'))
{
debug!("Unknown event: {variant_name}");
} else {
warn!("Err deserializing text: {err_dbg}");
}

let event_str = String::from_utf8_lossy(event);
debug!("Failing event data: {event_str}");
Error::Json(err)
})
}

struct ResumeMetadata {
session_id: FixedString,
resume_ws_url: FixedString,
Expand Down
1 change: 0 additions & 1 deletion src/gateway/sharding/shard_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ impl ShardRunner {

return Err(Error::Gateway(why));
},
Err(Error::Json(_)) => return Ok(None),
Err(why) => {
error!("Shard handler recieved err: {why:?}");
return Ok(None);
Expand Down
11 changes: 1 addition & 10 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,7 @@ impl WsClient {
};

match serde_json::from_slice(json_bytes) {
Ok(mut event) => {
if let GatewayEvent::Dispatch {
ref mut event, ..
} = event
{
*event = json_bytes.to_vec();
}

Ok(Some(event))
},
Ok(event) => Ok(Some(event)),
Err(err) => {
debug!("Failing text: {}", String::from_utf8_lossy(json_bytes));
Err(Error::Json(err))
Expand Down
48 changes: 39 additions & 9 deletions src/model/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,9 +932,7 @@ pub struct MessagePollVoteRemoveEvent {
pub enum GatewayEvent {
Dispatch {
seq: u64,
// Avoid deserialising straight away to handle errors and get access to `seq`.
// This must be filled in with original data by the caller after deserialisation.
event: Vec<u8>,
event: Box<DeserializedEvent>,
},
Heartbeat,
Reconnect,
Expand All @@ -944,6 +942,32 @@ pub enum GatewayEvent {
HeartbeatAck,
}

#[expect(clippy::large_enum_variant)]
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Clone, Debug, Serialize)]
#[non_exhaustive]
#[serde(untagged)]
pub enum DeserializedEvent {
Success(Event),
Unknown(UnknownEvent),
}

#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
pub struct UnknownEvent {
#[serde(rename = "t")]
pub ty: String,
#[serde(rename = "d")]
#[cfg_attr(feature = "typesize", typesize(with = raw_value_len))]
pub data: Box<RawValue>,
}

#[cfg(feature = "typesize")]
fn raw_value_len(val: &RawValue) -> usize {
val.get().len()
}

// Manual impl needed to emulate integer enum tags
impl<'de> Deserialize<'de> for GatewayEvent {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
Expand All @@ -968,21 +992,27 @@ impl<'de> Deserialize<'de> for GatewayEvent {

Self::Dispatch {
seq: raw.seq.ok_or_else(|| DeError::missing_field("s"))?,
event: Vec::new(),
event: {
Box::new(match Event::deserialize(raw.data) {
Ok(event) => DeserializedEvent::Success(event),
Err(_) => DeserializedEvent::Unknown(
UnknownEvent::deserialize(raw.data).map_err(DeError::custom)?,
),
})
},
}
},
Opcode::Heartbeat => Self::Heartbeat,
Opcode::InvalidSession => Self::InvalidateSession(
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?,
),
Opcode::InvalidSession => {
Self::InvalidateSession(bool::deserialize(raw.data).map_err(DeError::custom)?)
},
Opcode::Hello => {
#[derive(Deserialize)]
struct HelloPayload {
heartbeat_interval: u64,
}

let inner: HelloPayload =
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?;
let inner = HelloPayload::deserialize(raw.data).map_err(DeError::custom)?;

Self::Hello(inner.heartbeat_interval)
},
Expand Down
Loading