Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ rust-version = "1.82"
[dependencies]
# Required dependencies
bitflags = "2.4.2"
serde_json = "1.0.108"
serde_json = { version = "1.0.108", features = ["raw_value"] }
async-trait = "0.1.74"
tracing = { version = "0.1.40", features = ["log"] }
serde = { version = "1.0.192", features = ["derive", "rc"] }
Expand Down
26 changes: 9 additions & 17 deletions src/gateway/sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ use std::time::{Duration as StdDuration, Instant};
#[cfg(any(feature = "transport_compression_zlib", feature = "transport_compression_zstd"))]
use aformat::aformat_into;
use aformat::{aformat, ArrayString, CapStr};
use serde::Deserialize;
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
#[cfg(feature = "tracing_instrument")]
Expand Down Expand Up @@ -319,18 +318,13 @@ impl Shard {
}

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
fn handle_gateway_dispatch(
&mut self,
seq: u64,
event: JsonMap,
original_str: &str,
) -> Result<Event> {
fn handle_gateway_dispatch(&mut self, seq: u64, event: &[u8]) -> Result<Event> {
if seq > self.seq + 1 {
warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq);
}

self.seq = seq;
let event = deserialize_and_log_event(event, original_str)?;
let event = deserialize_and_log_event(event)?;

match &event {
Event::Ready(ready) => {
Expand Down Expand Up @@ -453,11 +447,8 @@ impl Shard {
match event {
Ok(GatewayEvent::Dispatch {
seq,
data,
original_str,
}) => self
.handle_gateway_dispatch(seq, data, &original_str)
.map(|e| Some(ShardAction::Dispatch(e))),
event,
}) => self.handle_gateway_dispatch(seq, &event).map(|e| Some(ShardAction::Dispatch(e))),
Ok(GatewayEvent::Heartbeat) => {
info!("[{:?}] Received shard heartbeat", self.shard_info);

Expand Down Expand Up @@ -749,9 +740,8 @@ async fn connect(base_url: &str, compression: TransportCompression) -> Result<Ws
WsClient::connect(url, compression).await
}

fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event> {
Event::deserialize(Value::Object(map)).map_err(|err| {
let err = serde::de::Error::custom(err);
fn deserialize_and_log_event(event: &[u8]) -> Result<Event> {
serde_json::from_slice(event).map_err(|err| {
let err_dbg = format!("{err:?}");
if let Some((variant_name, _)) =
err_dbg.strip_prefix(r#"Error("unknown variant `"#).and_then(|s| s.split_once('`'))
Expand All @@ -760,7 +750,9 @@ fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event>
} else {
warn!("Err deserializing text: {err_dbg}");
}
debug!("Failing text: {original_str}");

let event_str = String::from_utf8_lossy(event);
debug!("Failing event data: {event_str}");
Error::Json(err)
})
}
Expand Down
25 changes: 9 additions & 16 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::env::consts;
use std::io::Read;
use std::time::SystemTime;
Expand All @@ -7,7 +6,6 @@ use flate2::read::ZlibDecoder;
#[cfg(feature = "transport_compression_zlib")]
use flate2::Decompress as ZlibInflater;
use futures::{SinkExt, StreamExt};
use small_fixed_array::FixedString;
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
use tokio_tungstenite::tungstenite::protocol::{CloseFrame, WebSocketConfig};
Expand Down Expand Up @@ -254,35 +252,30 @@ impl WsClient {
};

let json_bytes = match message {
Message::Text(payload) => Cow::Owned(payload.as_bytes().to_vec()),
Message::Binary(bytes) => {
let Some(decompressed) = self.compression.inflate(&bytes)? else {
return Ok(None);
};

Cow::Borrowed(decompressed)
Message::Text(ref payload) => payload.as_bytes(),
Message::Binary(ref bytes) => match self.compression.inflate(bytes)? {
Some(decompressed) => decompressed,
None => return Ok(None),
},
Message::Close(Some(frame)) => {
return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
},
_ => return Ok(None),
};

// TODO: Use `String::from_utf8_lossy_owned` when stable.
let json_str = || String::from_utf8_lossy(&json_bytes);
match serde_json::from_slice(&json_bytes) {
match serde_json::from_slice(json_bytes) {
Ok(mut event) => {
if let GatewayEvent::Dispatch {
original_str, ..
} = &mut event
ref mut event, ..
} = event
{
*original_str = FixedString::from_string_trunc(json_str().into_owned());
*event = json_bytes.to_vec();
}

Ok(Some(event))
},
Err(err) => {
debug!("Failing text: {}", json_str());
debug!("Failing text: {}", String::from_utf8_lossy(json_bytes));
Err(Error::Json(err))
},
}
Expand Down
42 changes: 29 additions & 13 deletions src/model/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

use serde::de::Error as DeError;
use serde::Serialize;
use serde_json::value::RawValue;
use strum::{EnumCount, IntoStaticStr, VariantNames};

use crate::constants::Opcode;
use crate::internal::utils::lending_for_each;
use crate::model::prelude::*;
use crate::model::utils::remove_from_map;

/// Requires no gateway intents.
///
Expand Down Expand Up @@ -933,9 +933,8 @@ pub enum GatewayEvent {
Dispatch {
seq: u64,
// Avoid deserialising straight away to handle errors and get access to `seq`.
data: JsonMap,
// Used for debugging, if the data cannot be deserialised.
original_str: FixedString,
// This must be filled in with original data by the caller after deserialisation.
event: Vec<u8>,
},
Heartbeat,
Reconnect,
Expand All @@ -948,26 +947,43 @@ pub enum GatewayEvent {
// Manual impl needed to emulate integer enum tags
impl<'de> Deserialize<'de> for GatewayEvent {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
let mut map = JsonMap::deserialize(deserializer)?;

Ok(match remove_from_map(&mut map, "op")? {
#[derive(Debug, Clone, Deserialize)]
struct GatewayEventRaw<'a> {
op: Opcode,
#[serde(rename = "s")]
seq: Option<u64>,
#[serde(rename = "d")]
data: &'a RawValue,
#[serde(rename = "t")]
ty: Option<&'a str>,
}

let raw = GatewayEventRaw::deserialize(deserializer)?;

Ok(match raw.op {
Opcode::Dispatch => {
if raw.ty.is_none() {
return Err(DeError::missing_field("t"));
}

Self::Dispatch {
seq: remove_from_map(&mut map, "s")?,
// Filled in in recv_event
original_str: FixedString::new(),
data: map,
seq: raw.seq.ok_or_else(|| DeError::missing_field("s"))?,
event: Vec::new(),
}
},
Opcode::Heartbeat => Self::Heartbeat,
Opcode::InvalidSession => Self::InvalidateSession(remove_from_map(&mut map, "d")?),
Opcode::InvalidSession => Self::InvalidateSession(
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?,
),
Opcode::Hello => {
#[derive(Deserialize)]
struct HelloPayload {
heartbeat_interval: u64,
}

let inner: HelloPayload = remove_from_map(&mut map, "d")?;
let inner: HelloPayload =
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?;

Self::Hello(inner.heartbeat_interval)
},
Opcode::Reconnect => Self::Reconnect,
Expand Down
Loading