Skip to content

Commit 5fe2801

Browse files
committed
Avoid temporarily deserializing gateway messages to a serde_json::Map
1 parent 3923997 commit 5fe2801

File tree

4 files changed

+54
-56
lines changed

4 files changed

+54
-56
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ rust-version = "1.82"
2929
[dependencies]
3030
# Required dependencies
3131
bitflags = "2.4.2"
32-
serde_json = "1.0.108"
32+
serde_json = { version = "1.0.108", features = ["raw_value"] }
3333
async-trait = "0.1.74"
3434
tracing = { version = "0.1.40", features = ["log"] }
3535
serde = { version = "1.0.192", features = ["derive", "rc"] }

src/gateway/sharding/mod.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ use std::time::{Duration as StdDuration, Instant};
4747
#[cfg(any(feature = "transport_compression_zlib", feature = "transport_compression_zstd"))]
4848
use aformat::aformat_into;
4949
use aformat::{aformat, ArrayString, CapStr};
50-
use serde::Deserialize;
5150
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
5251
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
5352
#[cfg(feature = "tracing_instrument")]
@@ -319,18 +318,13 @@ impl Shard {
319318
}
320319

321320
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
322-
fn handle_gateway_dispatch(
323-
&mut self,
324-
seq: u64,
325-
event: JsonMap,
326-
original_str: &str,
327-
) -> Result<Event> {
321+
fn handle_gateway_dispatch(&mut self, seq: u64, event: &[u8]) -> Result<Event> {
328322
if seq > self.seq + 1 {
329323
warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq);
330324
}
331325

332326
self.seq = seq;
333-
let event = deserialize_and_log_event(event, original_str)?;
327+
let event = deserialize_and_log_event(event)?;
334328

335329
match &event {
336330
Event::Ready(ready) => {
@@ -453,11 +447,8 @@ impl Shard {
453447
match event {
454448
Ok(GatewayEvent::Dispatch {
455449
seq,
456-
data,
457-
original_str,
458-
}) => self
459-
.handle_gateway_dispatch(seq, data, &original_str)
460-
.map(|e| Some(ShardAction::Dispatch(e))),
450+
event,
451+
}) => self.handle_gateway_dispatch(seq, &event).map(|e| Some(ShardAction::Dispatch(e))),
461452
Ok(GatewayEvent::Heartbeat(..)) => {
462453
info!("[{:?}] Received shard heartbeat", self.shard_info);
463454

@@ -749,8 +740,8 @@ async fn connect(base_url: &str, compression: TransportCompression) -> Result<Ws
749740
WsClient::connect(url, compression).await
750741
}
751742

752-
fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event> {
753-
Event::deserialize(Value::Object(map)).map_err(|err| {
743+
fn deserialize_and_log_event(event: &[u8]) -> Result<Event> {
744+
serde_json::from_slice(event).map_err(|err| {
754745
let err = serde::de::Error::custom(err);
755746
let err_dbg = format!("{err:?}");
756747
if let Some((variant_name, _)) =
@@ -760,7 +751,9 @@ fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event>
760751
} else {
761752
warn!("Err deserializing text: {err_dbg}");
762753
}
763-
debug!("Failing text: {original_str}");
754+
755+
let event_str = String::from_utf8_lossy(event);
756+
debug!("Failing event data: {event_str}");
764757
Error::Json(err)
765758
})
766759
}

src/gateway/ws.rs

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::env::consts;
32
use std::io::Read;
43
use std::time::SystemTime;
@@ -7,7 +6,6 @@ use flate2::read::ZlibDecoder;
76
#[cfg(feature = "transport_compression_zlib")]
87
use flate2::Decompress as ZlibInflater;
98
use futures::{SinkExt, StreamExt};
10-
use small_fixed_array::FixedString;
119
use tokio::net::TcpStream;
1210
use tokio::time::{timeout, Duration};
1311
use tokio_tungstenite::tungstenite::protocol::{CloseFrame, WebSocketConfig};
@@ -254,35 +252,21 @@ impl WsClient {
254252
};
255253

256254
let json_bytes = match message {
257-
Message::Text(payload) => Cow::Owned(payload.as_bytes().to_vec()),
258-
Message::Binary(bytes) => {
259-
let Some(decompressed) = self.compression.inflate(&bytes)? else {
260-
return Ok(None);
261-
};
262-
263-
Cow::Borrowed(decompressed)
255+
Message::Text(ref payload) => payload.as_bytes(),
256+
Message::Binary(ref bytes) => match self.compression.inflate(bytes)? {
257+
Some(decompressed) => decompressed,
258+
None => return Ok(None),
264259
},
265260
Message::Close(Some(frame)) => {
266261
return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
267262
},
268263
_ => return Ok(None),
269264
};
270265

271-
// TODO: Use `String::from_utf8_lossy_owned` when stable.
272-
let json_str = || String::from_utf8_lossy(&json_bytes);
273-
match serde_json::from_slice(&json_bytes) {
274-
Ok(mut event) => {
275-
if let GatewayEvent::Dispatch {
276-
original_str, ..
277-
} = &mut event
278-
{
279-
*original_str = FixedString::from_string_trunc(json_str().into_owned());
280-
}
281-
282-
Ok(Some(event))
283-
},
266+
match serde_json::from_slice(json_bytes) {
267+
Ok(event) => Ok(Some(event)),
284268
Err(err) => {
285-
debug!("Failing text: {}", json_str());
269+
debug!("Failing text: {}", String::from_utf8_lossy(json_bytes));
286270
Err(Error::Json(err))
287271
},
288272
}

src/model/event.rs

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
66
use serde::de::Error as DeError;
77
use serde::Serialize;
8+
use serde_json::value::RawValue;
89
use strum::{EnumCount, IntoStaticStr, VariantNames};
910

1011
use crate::constants::Opcode;
1112
use crate::internal::utils::lending_for_each;
1213
use crate::model::prelude::*;
13-
use crate::model::utils::remove_from_map;
1414

1515
/// Requires no gateway intents.
1616
///
@@ -933,9 +933,7 @@ pub enum GatewayEvent {
933933
Dispatch {
934934
seq: u64,
935935
// Avoid deserialising straight away to handle errors and get access to `seq`.
936-
data: JsonMap,
937-
// Used for debugging, if the data cannot be deserialised.
938-
original_str: FixedString,
936+
event: Vec<u8>,
939937
},
940938
Heartbeat(#[deprecated = "always 0 because it is never provided by the gateway"] u64),
941939
Reconnect,
@@ -948,30 +946,53 @@ pub enum GatewayEvent {
948946
// Manual impl needed to emulate integer enum tags
949947
impl<'de> Deserialize<'de> for GatewayEvent {
950948
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
951-
let mut map = JsonMap::deserialize(deserializer)?;
952-
953-
Ok(match remove_from_map(&mut map, "op")? {
954-
Opcode::Dispatch => {
955-
Self::Dispatch {
956-
seq: remove_from_map(&mut map, "s")?,
957-
// Filled in in recv_event
958-
original_str: FixedString::new(),
959-
data: map,
960-
}
949+
#[derive(Debug, Clone, Deserialize)]
950+
struct GatewayEventRaw<'a> {
951+
op: Opcode,
952+
#[serde(rename = "s")]
953+
seq: Option<u64>,
954+
#[serde(rename = "d")]
955+
data: &'a RawValue,
956+
#[serde(rename = "t")]
957+
ty: Option<&'a str>,
958+
}
959+
960+
#[derive(Debug, Clone, Serialize)]
961+
struct UndeserializedEvent<'a> {
962+
#[serde(rename = "d")]
963+
data: &'a RawValue,
964+
#[serde(rename = "t")]
965+
ty: &'a str,
966+
}
967+
968+
let raw: GatewayEventRaw<'_> = Deserialize::deserialize(deserializer)?;
969+
970+
Ok(match raw.op {
971+
Opcode::Dispatch => Self::Dispatch {
972+
seq: raw.seq.ok_or_else(|| DeError::custom("missing seq"))?,
973+
event: serde_json::to_vec(&UndeserializedEvent {
974+
data: raw.data,
975+
ty: raw.ty.ok_or_else(|| DeError::custom("missing t"))?,
976+
})
977+
.map_err(DeError::custom)?,
961978
},
962979
Opcode::Heartbeat => {
963980
// Placeholder value. Discord expects the last Dispatch
964981
// sequence number and doesn't send it with the heartbeat.
965982
Self::Heartbeat(0)
966983
},
967-
Opcode::InvalidSession => Self::InvalidateSession(remove_from_map(&mut map, "d")?),
984+
Opcode::InvalidSession => Self::InvalidateSession(
985+
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?,
986+
),
968987
Opcode::Hello => {
969988
#[derive(Deserialize)]
970989
struct HelloPayload {
971990
heartbeat_interval: u64,
972991
}
973992

974-
let inner: HelloPayload = remove_from_map(&mut map, "d")?;
993+
let inner: HelloPayload =
994+
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?;
995+
975996
Self::Hello(inner.heartbeat_interval)
976997
},
977998
Opcode::Reconnect => Self::Reconnect,

0 commit comments

Comments
 (0)