Skip to content

Commit

Permalink
Clean up EncodeBody API (#1924)
Browse files Browse the repository at this point in the history
* tonic: remove unnecessary bounds in codec::encode

* tonic: reduce unnecessary visibility for EncodedBytes

* tonic: use EncodedBytes directly in EncodeBody

* tonic: handle fuse() inside EncodedBytes

* tonic: fold encode_server() into EncodeBody::new_server()

* tonic: move mapping responsibility to caller

* tonic: fold encode_client() into EncodeBody::new_client()
  • Loading branch information
djc authored Sep 26, 2024
1 parent e6782fe commit 3c900eb
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 80 deletions.
2 changes: 1 addition & 1 deletion tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ percent-encoding = "2.1"
pin-project = "1.0.11"
tower-layer = "0.3"
tower-service = "0.3"
tokio-stream = {version = "0.1", default-features = false}
tokio-stream = {version = "0.1.16", default-features = false}

# prost
prost = {version = "0.13", default-features = false, features = ["std"], optional = true}
Expand Down
7 changes: 4 additions & 3 deletions tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
use crate::codec::EncodeBody;
use crate::metadata::GRPC_CONTENT_TYPE;
use crate::{
body::BoxBody,
client::GrpcService,
codec::{encode_client, Codec, Decoder, Streaming},
codec::{Codec, Decoder, Streaming},
request::SanitizeHeaders,
Code, Request, Response, Status,
};
Expand Down Expand Up @@ -295,9 +296,9 @@ impl<T> Grpc<T> {
{
let request = request
.map(|s| {
encode_client(
EncodeBody::new_client(
codec.encoder(),
s,
s.map(Ok),
self.config.send_compression_encodings,
self.config.max_encoding_message_size,
)
Expand Down
111 changes: 42 additions & 69 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,7 @@ use std::{
pin::Pin,
task::{ready, Context, Poll},
};
use tokio_stream::{Stream, StreamExt};

/// Turns a stream of grpc results (message or error status) into [EncodeBody] which is used by grpc
/// servers for turning the messages into http frames for sending over the network.
pub fn encode_server<T, U>(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let stream = EncodedBytes::new(
encoder,
source.fuse(),
compression_encoding,
compression_override,
max_message_size,
);

EncodeBody::new_server(stream)
}

/// Turns a stream of grpc messages into [EncodeBody] which is used by grpc clients for
/// turning the messages into http frames for sending over the network.
pub fn encode_client<T, U>(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
T: Encoder<Error = Status>,
U: Stream<Item = T::Item>,
{
let stream = EncodedBytes::new(
encoder,
source.fuse().map(Ok),
compression_encoding,
SingleMessageCompressionOverride::default(),
max_message_size,
);
EncodeBody::new_client(stream)
}
use tokio_stream::{adapters::Fuse, Stream, StreamExt};

/// Combinator for efficient encoding of messages into reasonably sized buffers.
/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
Expand All @@ -66,13 +20,9 @@ where
/// * The encoded buffer surpasses YIELD_THRESHOLD.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
pub(crate) struct EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
struct EncodedBytes<T, U> {
#[pin]
source: U,
source: Fuse<U>,
encoder: T,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
Expand All @@ -81,12 +31,7 @@ where
error: Option<Status>,
}

impl<T, U> EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
// `source` should be fused stream.
impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
fn new(
encoder: T,
source: U,
Expand All @@ -111,7 +56,7 @@ where
};

Self {
source,
source: source.fuse(),
encoder,
compression_encoding,
max_message_size,
Expand Down Expand Up @@ -270,9 +215,9 @@ enum Role {
/// A specialized implementation of [Body] for encoding [Result<Bytes, Status>].
#[pin_project]
#[derive(Debug)]
pub struct EncodeBody<S> {
pub struct EncodeBody<T, U> {
#[pin]
inner: S,
inner: EncodedBytes<T, U>,
state: EncodeState,
}

Expand All @@ -283,10 +228,23 @@ struct EncodeState {
is_end_stream: bool,
}

impl<S> EncodeBody<S> {
fn new_client(inner: S) -> Self {
impl<T: Encoder, U: Stream> EncodeBody<T, U> {
/// Turns a stream of grpc messages into [EncodeBody] which is used by grpc clients for
/// turning the messages into http frames for sending over the network.
pub fn new_client(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> Self {
Self {
inner,
inner: EncodedBytes::new(
encoder,
source,
compression_encoding,
SingleMessageCompressionOverride::default(),
max_message_size,
),
state: EncodeState {
error: None,
role: Role::Client,
Expand All @@ -295,9 +253,23 @@ impl<S> EncodeBody<S> {
}
}

fn new_server(inner: S) -> Self {
/// Turns a stream of grpc results (message or error status) into [EncodeBody] which is used by grpc
/// servers for turning the messages into http frames for sending over the network.
pub fn new_server(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> Self {
Self {
inner,
inner: EncodedBytes::new(
encoder,
source,
compression_encoding,
compression_override,
max_message_size,
),
state: EncodeState {
error: None,
role: Role::Server,
Expand Down Expand Up @@ -328,9 +300,10 @@ impl EncodeState {
}
}

impl<S> Body for EncodeBody<S>
impl<T, U> Body for EncodeBody<T, U>
where
S: Stream<Item = Result<Bytes, Status>>,
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
type Data = Bytes;
type Error = Status;
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::io;
pub use self::buffer::{DecodeBuf, EncodeBuf};
pub use self::compression::{CompressionEncoding, EnabledCompressionEncodings};
pub use self::decode::Streaming;
pub use self::encode::{encode_client, encode_server, EncodeBody};
pub use self::encode::EncodeBody;
#[cfg(feature = "prost")]
pub use self::prost::ProstCodec;

Expand Down
10 changes: 6 additions & 4 deletions tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn from_decode_error(error: prost::DecodeError) -> crate::Status {
mod tests {
use crate::codec::compression::SingleMessageCompressionOverride;
use crate::codec::{
encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
};
use crate::Status;
use bytes::{Buf, BufMut, BytesMut};
Expand Down Expand Up @@ -228,7 +228,7 @@ mod tests {
let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
let source = tokio_stream::iter(messages);

let mut body = pin!(encode_server(
let mut body = pin!(EncodeBody::new_server(
encoder,
source,
None,
Expand All @@ -250,7 +250,7 @@ mod tests {
let messages = std::iter::once(Ok::<_, Status>(msg));
let source = tokio_stream::iter(messages);

let mut body = pin!(encode_server(
let mut body = pin!(EncodeBody::new_server(
encoder,
source,
None,
Expand Down Expand Up @@ -278,14 +278,16 @@ mod tests {
#[cfg(not(target_family = "windows"))]
#[tokio::test]
async fn encode_too_big() {
use crate::codec::EncodeBody;

let encoder = MockEncoder::default();

let msg = vec![0u8; u32::MAX as usize + 1];

let messages = std::iter::once(Ok::<_, Status>(msg));
let source = tokio_stream::iter(messages);

let mut body = pin!(encode_server(
let mut body = pin!(EncodeBody::new_server(
encoder,
source,
None,
Expand Down
5 changes: 3 additions & 2 deletions tonic/src/server/grpc.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::codec::compression::{
CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
};
use crate::codec::EncodeBody;
use crate::metadata::GRPC_CONTENT_TYPE;
use crate::{
body::BoxBody,
codec::{encode_server, Codec, Streaming},
codec::{Codec, Streaming},
server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
Request, Status,
};
Expand Down Expand Up @@ -447,7 +448,7 @@ where
);
}

let body = encode_server(
let body = EncodeBody::new_server(
self.codec.encoder(),
body,
accept_encoding,
Expand Down

0 comments on commit 3c900eb

Please sign in to comment.