Skip to content

Commit

Permalink
Support zstd (de)compression (#322)
Browse files Browse the repository at this point in the history
* Support zstd compression

* fix

* Give zstd highest priority for `accept-encoding`

* fix typo

* Use `.zst` extension for files precompressed with zstd

* don't pull in two `zstd`s
  • Loading branch information
davidpdrsn authored Feb 24, 2023
1 parent 987f5c9 commit 255c28e
Show file tree
Hide file tree
Showing 16 changed files with 324 additions and 21 deletions.
2 changes: 1 addition & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Added

- None.
- **compression, decompression:** Support `zstd` compression

## Changed

Expand Down
7 changes: 5 additions & 2 deletions tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ tower = { version = "0.4.10", features = ["buffer", "util", "retry", "make", "ti
tracing-subscriber = "0.3"
uuid = { version = "1.0", features = ["v4"] }
serde_json = "1.0"
zstd = "0.11"

[features]
default = []
Expand Down Expand Up @@ -103,13 +104,15 @@ validate-request = ["mime"]

compression-br = ["async-compression/brotli", "tokio-util", "tokio"]
compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"]
compression-full = ["compression-br", "compression-deflate", "compression-gzip"]
compression-full = ["compression-br", "compression-deflate", "compression-gzip", "compression-zstd"]
compression-gzip = ["async-compression/gzip", "tokio-util", "tokio"]
compression-zstd = ["async-compression/zstd", "tokio-util", "tokio"]

decompression-br = ["async-compression/brotli", "tokio-util", "tokio"]
decompression-deflate = ["async-compression/zlib", "tokio-util", "tokio"]
decompression-full = ["decompression-br", "decompression-deflate", "decompression-gzip"]
decompression-full = ["decompression-br", "decompression-deflate", "decompression-gzip", "decompression-zstd"]
decompression-gzip = ["async-compression/gzip", "tokio-util", "tokio"]
decompression-zstd = ["async-compression/zstd", "tokio-util", "tokio"]

[package.metadata.docs.rs]
all-features = true
Expand Down
12 changes: 8 additions & 4 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
#[cfg(any(
feature = "compression-br",
feature = "compression-deflate",
feature = "compression-gzip"
feature = "compression-gzip",
feature = "compression-zstd",
))]
fn compression(self) -> ServiceBuilder<Stack<crate::compression::CompressionLayer, L>>;

Expand All @@ -108,7 +109,8 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
#[cfg(any(
feature = "decompression-br",
feature = "decompression-deflate",
feature = "decompression-gzip"
feature = "decompression-gzip",
feature = "decompression-zstd",
))]
fn decompression(self) -> ServiceBuilder<Stack<crate::decompression::DecompressionLayer, L>>;

Expand Down Expand Up @@ -405,7 +407,8 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
#[cfg(any(
feature = "compression-br",
feature = "compression-deflate",
feature = "compression-gzip"
feature = "compression-gzip",
feature = "compression-zstd",
))]
fn compression(self) -> ServiceBuilder<Stack<crate::compression::CompressionLayer, L>> {
self.layer(crate::compression::CompressionLayer::new())
Expand All @@ -414,7 +417,8 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
#[cfg(any(
feature = "decompression-br",
feature = "decompression-deflate",
feature = "decompression-gzip"
feature = "decompression-gzip",
feature = "decompression-zstd",
))]
fn decompression(self) -> ServiceBuilder<Stack<crate::decompression::DecompressionLayer, L>> {
self.layer(crate::decompression::DecompressionLayer::new())
Expand Down
55 changes: 55 additions & 0 deletions tower-http/src/compression/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use async_compression::tokio::bufread::BrotliEncoder;
use async_compression::tokio::bufread::GzipEncoder;
#[cfg(feature = "compression-deflate")]
use async_compression::tokio::bufread::ZlibEncoder;
#[cfg(feature = "compression-zstd")]
use async_compression::tokio::bufread::ZstdEncoder;
use bytes::{Buf, Bytes};
use futures_util::ready;
use http::HeaderMap;
Expand Down Expand Up @@ -55,6 +57,8 @@ where
BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
#[cfg(feature = "compression-br")]
BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
#[cfg(feature = "compression-zstd")]
BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(),
BodyInner::Identity { inner } => inner,
}
}
Expand All @@ -68,6 +72,8 @@ where
BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
#[cfg(feature = "compression-br")]
BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
#[cfg(feature = "compression-zstd")]
BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(),
BodyInner::Identity { inner } => inner,
}
}
Expand Down Expand Up @@ -99,6 +105,14 @@ where
.get_pin_mut()
.get_pin_mut()
.get_pin_mut(),
#[cfg(feature = "compression-zstd")]
BodyInnerProj::Zstd { inner } => inner
.project()
.read
.get_pin_mut()
.get_pin_mut()
.get_pin_mut()
.get_pin_mut(),
BodyInnerProj::Identity { inner } => inner,
}
}
Expand Down Expand Up @@ -127,6 +141,13 @@ where
.into_inner()
.into_inner()
.into_inner(),
#[cfg(feature = "compression-zstd")]
BodyInner::Zstd { inner } => inner
.read
.into_inner()
.into_inner()
.into_inner()
.into_inner(),
BodyInner::Identity { inner } => inner,
}
}
Expand All @@ -141,6 +162,9 @@ type DeflateBody<B> = WrapBody<ZlibEncoder<B>>;
#[cfg(feature = "compression-br")]
type BrotliBody<B> = WrapBody<BrotliEncoder<B>>;

#[cfg(feature = "compression-zstd")]
type ZstdBody<B> = WrapBody<ZstdEncoder<B>>;

pin_project_cfg! {
#[project = BodyInnerProj]
pub(crate) enum BodyInner<B>
Expand All @@ -162,6 +186,11 @@ pin_project_cfg! {
#[pin]
inner: BrotliBody<B>,
},
#[cfg(feature = "compression-zstd")]
Zstd {
#[pin]
inner: ZstdBody<B>,
},
Identity {
#[pin]
inner: B,
Expand All @@ -185,6 +214,11 @@ impl<B: Body> BodyInner<B> {
Self::Brotli { inner }
}

#[cfg(feature = "compression-zstd")]
pub(crate) fn zstd(inner: WrapBody<ZstdEncoder<B>>) -> Self {
Self::Zstd { inner }
}

pub(crate) fn identity(inner: B) -> Self {
Self::Identity { inner }
}
Expand All @@ -209,6 +243,8 @@ where
BodyInnerProj::Deflate { inner } => inner.poll_data(cx),
#[cfg(feature = "compression-br")]
BodyInnerProj::Brotli { inner } => inner.poll_data(cx),
#[cfg(feature = "compression-zstd")]
BodyInnerProj::Zstd { inner } => inner.poll_data(cx),
BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) {
Some(Ok(mut buf)) => {
let bytes = buf.copy_to_bytes(buf.remaining());
Expand All @@ -231,6 +267,8 @@ where
BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx),
#[cfg(feature = "compression-br")]
BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx),
#[cfg(feature = "compression-zstd")]
BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx),
BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into),
}
}
Expand Down Expand Up @@ -286,3 +324,20 @@ where
pinned.get_pin_mut()
}
}

#[cfg(feature = "compression-zstd")]
impl<B> DecorateAsyncRead for ZstdEncoder<B>
where
B: Body,
{
type Input = AsyncReadBody<B>;
type Output = ZstdEncoder<Self::Input>;

fn apply(input: Self::Input) -> Self::Output {
ZstdEncoder::new(input)
}

fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
pinned.get_pin_mut()
}
}
2 changes: 2 additions & 0 deletions tower-http/src/compression/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ where
(_, Encoding::Deflate) => CompressionBody::new(BodyInner::deflate(WrapBody::new(body))),
#[cfg(feature = "compression-br")]
(_, Encoding::Brotli) => CompressionBody::new(BodyInner::brotli(WrapBody::new(body))),
#[cfg(feature = "compression-zstd")]
(_, Encoding::Zstd) => CompressionBody::new(BodyInner::zstd(WrapBody::new(body))),
#[cfg(feature = "fs")]
(true, _) => {
// This should never happen because the `AcceptEncoding` struct which is used to determine
Expand Down
15 changes: 15 additions & 0 deletions tower-http/src/compression/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ impl CompressionLayer {
self
}

/// Sets whether to enable the Zstd encoding.
#[cfg(feature = "compression-zstd")]
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}

/// Disables the gzip encoding.
///
/// This method is available even if the `gzip` crate feature is disabled.
Expand All @@ -81,6 +88,14 @@ impl CompressionLayer {
self
}

/// Disables the Zstd encoding.
///
/// This method is available even if the `zstd` crate feature is disabled.
pub fn no_zstd(mut self) -> Self {
self.accept.set_zstd(false);
self
}

/// Replace the current compression predicate.
///
/// See [`Compression::compress_when`] for more details.
Expand Down
30 changes: 29 additions & 1 deletion tower-http/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ mod tests {
}

#[tokio::test]
async fn works() {
async fn gzip_works() {
let svc = service_fn(handle);
let mut svc = Compression::new(svc).compress_when(Always);

Expand Down Expand Up @@ -141,6 +141,34 @@ mod tests {
assert_eq!(decompressed, "Hello, World!");
}

#[tokio::test]
async fn zstd_works() {
let svc = service_fn(handle);
let mut svc = Compression::new(svc).compress_when(Always);

// call the service
let req = Request::builder()
.header("accept-encoding", "zstd")
.body(Body::empty())
.unwrap();
let res = svc.ready().await.unwrap().call(req).await.unwrap();

// read the compressed body
let mut body = res.into_body();
let mut data = BytesMut::new();
while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
data.extend_from_slice(&chunk[..]);
}
let compressed_data = data.freeze().to_vec();

// decompress the body
let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
let decompressed = String::from_utf8(decompressed).unwrap();

assert_eq!(decompressed, "Hello, World!");
}

#[allow(dead_code)]
async fn is_compatible_with_hyper() {
let svc = service_fn(handle);
Expand Down
15 changes: 15 additions & 0 deletions tower-http/src/compression/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ impl<S, P> Compression<S, P> {
self
}

/// Sets whether to enable the Zstd encoding.
#[cfg(feature = "compression-zstd")]
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}

/// Disables the gzip encoding.
///
/// This method is available even if the `gzip` crate feature is disabled.
Expand All @@ -85,6 +92,14 @@ impl<S, P> Compression<S, P> {
self
}

/// Disables the Zstd encoding.
///
/// This method is available even if the `zstd` crate feature is disabled.
pub fn no_zstd(mut self) -> Self {
self.accept.set_zstd(false);
self
}

/// Replace the current compression predicate.
///
/// Predicates are used to determine whether a response should be compressed or not.
Expand Down
45 changes: 36 additions & 9 deletions tower-http/src/compression_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,29 @@ pub(crate) struct AcceptEncoding {
pub(crate) gzip: bool,
pub(crate) deflate: bool,
pub(crate) br: bool,
pub(crate) zstd: bool,
}

impl AcceptEncoding {
#[allow(dead_code)]
pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
let accept = match (self.gzip(), self.deflate(), self.br()) {
(true, true, true) => "gzip,deflate,br",
(true, true, false) => "gzip,deflate",
(true, false, true) => "gzip,br",
(true, false, false) => "gzip",
(false, true, true) => "deflate,br",
(false, true, false) => "deflate",
(false, false, true) => "br",
(false, false, false) => return None,
let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
(true, true, true, false) => "gzip,deflate,br",
(true, true, false, false) => "gzip,deflate",
(true, false, true, false) => "gzip,br",
(true, false, false, false) => "gzip",
(false, true, true, false) => "deflate,br",
(false, true, false, false) => "deflate",
(false, false, true, false) => "br",
(true, true, true, true) => "zstd,gzip,deflate,br",
(true, true, false, true) => "zstd,gzip,deflate",
(true, false, true, true) => "zstd,gzip,br",
(true, false, false, true) => "zstd,gzip",
(false, true, true, true) => "zstd,deflate,br",
(false, true, false, true) => "zstd,deflate",
(false, false, true, true) => "zstd,br",
(false, false, false, true) => "zstd",
(false, false, false, false) => return None,
};
Some(HeaderValue::from_static(accept))
}
Expand All @@ -52,6 +61,11 @@ impl AcceptEncoding {
pub(crate) fn set_br(&mut self, enable: bool) {
self.br = enable;
}

#[allow(dead_code)]
pub(crate) fn set_zstd(&mut self, enable: bool) {
self.zstd = enable;
}
}

impl SupportedEncodings for AcceptEncoding {
Expand Down Expand Up @@ -90,6 +104,18 @@ impl SupportedEncodings for AcceptEncoding {
false
}
}

#[allow(dead_code)]
fn zstd(&self) -> bool {
#[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
{
self.zstd
}
#[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
{
false
}
}
}

impl Default for AcceptEncoding {
Expand All @@ -98,6 +124,7 @@ impl Default for AcceptEncoding {
gzip: true,
deflate: true,
br: true,
zstd: true,
}
}
}
Expand Down
Loading

0 comments on commit 255c28e

Please sign in to comment.