From 255c28ef4925451750e2c3d1ad979e8f86ecea6b Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 24 Feb 2023 13:00:22 +0100 Subject: [PATCH] Support zstd (de)compression (#322) * Support zstd compression * fix * Give zstd highest priority for `accept-encoding` * fix typo * Use `.zst` extension for files precompressed with zstd * don't pull in two `zstd`s --- tower-http/CHANGELOG.md | 2 +- tower-http/Cargo.toml | 7 +- tower-http/src/builder.rs | 12 ++-- tower-http/src/compression/body.rs | 55 ++++++++++++++++ tower-http/src/compression/future.rs | 2 + tower-http/src/compression/layer.rs | 15 +++++ tower-http/src/compression/mod.rs | 30 ++++++++- tower-http/src/compression/service.rs | 15 +++++ tower-http/src/compression_utils.rs | 45 ++++++++++--- tower-http/src/content_encoding.rs | 24 ++++++- tower-http/src/decompression/body.rs | 71 ++++++++++++++++++++- tower-http/src/decompression/future.rs | 5 ++ tower-http/src/decompression/layer.rs | 15 +++++ tower-http/src/decompression/service.rs | 15 +++++ tower-http/src/lib.rs | 10 ++- tower-http/src/services/fs/serve_dir/mod.rs | 22 +++++++ 16 files changed, 324 insertions(+), 21 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 17757e18..5e2dec50 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- **compression, decompression:** Support `zstd` compression ## Changed diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index a453918d..b9f93646 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -50,6 +50,7 @@ tower = { version = "0.4.10", features = ["buffer", "util", "retry", "make", "ti tracing-subscriber = "0.3" uuid = { version = "1.0", features = ["v4"] } serde_json = "1.0" +zstd = "0.11" [features] default = [] @@ -103,13 +104,15 @@ validate-request = ["mime"] compression-br = ["async-compression/brotli", "tokio-util", "tokio"] compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] -compression-full = ["compression-br", "compression-deflate", "compression-gzip"] +compression-full = ["compression-br", "compression-deflate", "compression-gzip", "compression-zstd"] compression-gzip = ["async-compression/gzip", "tokio-util", "tokio"] +compression-zstd = ["async-compression/zstd", "tokio-util", "tokio"] decompression-br = ["async-compression/brotli", "tokio-util", "tokio"] decompression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] -decompression-full = ["decompression-br", "decompression-deflate", "decompression-gzip"] +decompression-full = ["decompression-br", "decompression-deflate", "decompression-gzip", "decompression-zstd"] decompression-gzip = ["async-compression/gzip", "tokio-util", "tokio"] +decompression-zstd = ["async-compression/zstd", "tokio-util", "tokio"] [package.metadata.docs.rs] all-features = true diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 9dcbfd24..2cb4f94a 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -96,7 +96,8 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { #[cfg(any( feature = "compression-br", feature = "compression-deflate", - feature = "compression-gzip" + feature = "compression-gzip", + feature = "compression-zstd", ))] fn compression(self) -> ServiceBuilder>; @@ -108,7 +109,8 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", - feature = "decompression-gzip" + feature = "decompression-gzip", + feature = "decompression-zstd", ))] fn decompression(self) -> ServiceBuilder>; @@ -405,7 +407,8 @@ impl ServiceBuilderExt for ServiceBuilder { #[cfg(any( feature = "compression-br", feature = "compression-deflate", - feature = "compression-gzip" + feature = "compression-gzip", + feature = "compression-zstd", ))] fn compression(self) -> ServiceBuilder> { self.layer(crate::compression::CompressionLayer::new()) @@ -414,7 +417,8 @@ impl ServiceBuilderExt for ServiceBuilder { #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", - feature = "decompression-gzip" + feature = "decompression-gzip", + feature = "decompression-zstd", ))] fn decompression(self) -> ServiceBuilder> { self.layer(crate::decompression::DecompressionLayer::new()) diff --git a/tower-http/src/compression/body.rs b/tower-http/src/compression/body.rs index 579f7325..b737151b 100644 --- a/tower-http/src/compression/body.rs +++ b/tower-http/src/compression/body.rs @@ -10,6 +10,8 @@ use async_compression::tokio::bufread::BrotliEncoder; use async_compression::tokio::bufread::GzipEncoder; #[cfg(feature = "compression-deflate")] use async_compression::tokio::bufread::ZlibEncoder; +#[cfg(feature = "compression-zstd")] +use async_compression::tokio::bufread::ZstdEncoder; use bytes::{Buf, Bytes}; use futures_util::ready; use http::HeaderMap; @@ -55,6 +57,8 @@ where BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "compression-br")] BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), BodyInner::Identity { inner } => inner, } } @@ -68,6 +72,8 @@ where BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "compression-br")] BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), BodyInner::Identity { inner } => inner, } } @@ -99,6 +105,14 @@ where .get_pin_mut() .get_pin_mut() .get_pin_mut(), + #[cfg(feature = "compression-zstd")] + BodyInnerProj::Zstd { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), BodyInnerProj::Identity { inner } => inner, } } @@ -127,6 +141,13 @@ where .into_inner() .into_inner() .into_inner(), + #[cfg(feature = "compression-zstd")] + BodyInner::Zstd { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), BodyInner::Identity { inner } => inner, } } @@ -141,6 +162,9 @@ type DeflateBody = WrapBody>; #[cfg(feature = "compression-br")] type BrotliBody = WrapBody>; +#[cfg(feature = "compression-zstd")] +type ZstdBody = WrapBody>; + pin_project_cfg! { #[project = BodyInnerProj] pub(crate) enum BodyInner @@ -162,6 +186,11 @@ pin_project_cfg! { #[pin] inner: BrotliBody, }, + #[cfg(feature = "compression-zstd")] + Zstd { + #[pin] + inner: ZstdBody, + }, Identity { #[pin] inner: B, @@ -185,6 +214,11 @@ impl BodyInner { Self::Brotli { inner } } + #[cfg(feature = "compression-zstd")] + pub(crate) fn zstd(inner: WrapBody>) -> Self { + Self::Zstd { inner } + } + pub(crate) fn identity(inner: B) -> Self { Self::Identity { inner } } @@ -209,6 +243,8 @@ where BodyInnerProj::Deflate { inner } => inner.poll_data(cx), #[cfg(feature = "compression-br")] BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + #[cfg(feature = "compression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_data(cx), BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { Some(Ok(mut buf)) => { let bytes = buf.copy_to_bytes(buf.remaining()); @@ -231,6 +267,8 @@ where BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), #[cfg(feature = "compression-br")] BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), + #[cfg(feature = "compression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), } } @@ -286,3 +324,20 @@ where pinned.get_pin_mut() } } + +#[cfg(feature = "compression-zstd")] +impl DecorateAsyncRead for ZstdEncoder +where + B: Body, +{ + type Input = AsyncReadBody; + type Output = ZstdEncoder; + + fn apply(input: Self::Input) -> Self::Output { + ZstdEncoder::new(input) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} diff --git a/tower-http/src/compression/future.rs b/tower-http/src/compression/future.rs index b70a83db..11f73424 100644 --- a/tower-http/src/compression/future.rs +++ b/tower-http/src/compression/future.rs @@ -60,6 +60,8 @@ where (_, Encoding::Deflate) => CompressionBody::new(BodyInner::deflate(WrapBody::new(body))), #[cfg(feature = "compression-br")] (_, Encoding::Brotli) => CompressionBody::new(BodyInner::brotli(WrapBody::new(body))), + #[cfg(feature = "compression-zstd")] + (_, Encoding::Zstd) => CompressionBody::new(BodyInner::zstd(WrapBody::new(body))), #[cfg(feature = "fs")] (true, _) => { // This should never happen because the `AcceptEncoding` struct which is used to determine diff --git a/tower-http/src/compression/layer.rs b/tower-http/src/compression/layer.rs index 0dd62063..b61fdc43 100644 --- a/tower-http/src/compression/layer.rs +++ b/tower-http/src/compression/layer.rs @@ -57,6 +57,13 @@ impl CompressionLayer { self } + /// Sets whether to enable the Zstd encoding. + #[cfg(feature = "compression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. @@ -81,6 +88,14 @@ impl CompressionLayer { self } + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } + /// Replace the current compression predicate. /// /// See [`Compression::compress_when`] for more details. diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index 7f7c143f..db3143e7 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -111,7 +111,7 @@ mod tests { } #[tokio::test] - async fn works() { + async fn gzip_works() { let svc = service_fn(handle); let mut svc = Compression::new(svc).compress_when(Always); @@ -141,6 +141,34 @@ mod tests { assert_eq!(decompressed, "Hello, World!"); } + #[tokio::test] + async fn zstd_works() { + let svc = service_fn(handle); + let mut svc = Compression::new(svc).compress_when(Always); + + // call the service + let req = Request::builder() + .header("accept-encoding", "zstd") + .body(Body::empty()) + .unwrap(); + let res = svc.ready().await.unwrap().call(req).await.unwrap(); + + // read the compressed body + let mut body = res.into_body(); + let mut data = BytesMut::new(); + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + data.extend_from_slice(&chunk[..]); + } + let compressed_data = data.freeze().to_vec(); + + // decompress the body + let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap(); + let decompressed = String::from_utf8(decompressed).unwrap(); + + assert_eq!(decompressed, "Hello, World!"); + } + #[allow(dead_code)] async fn is_compatible_with_hyper() { let svc = service_fn(handle); diff --git a/tower-http/src/compression/service.rs b/tower-http/src/compression/service.rs index 91de2b50..17701456 100644 --- a/tower-http/src/compression/service.rs +++ b/tower-http/src/compression/service.rs @@ -61,6 +61,13 @@ impl Compression { self } + /// Sets whether to enable the Zstd encoding. + #[cfg(feature = "compression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. @@ -85,6 +92,14 @@ impl Compression { self } + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } + /// Replace the current compression predicate. /// /// Predicates are used to determine whether a response should be compressed or not. diff --git a/tower-http/src/compression_utils.rs b/tower-http/src/compression_utils.rs index b0c05726..15b63f1f 100644 --- a/tower-http/src/compression_utils.rs +++ b/tower-http/src/compression_utils.rs @@ -20,20 +20,29 @@ pub(crate) struct AcceptEncoding { pub(crate) gzip: bool, pub(crate) deflate: bool, pub(crate) br: bool, + pub(crate) zstd: bool, } impl AcceptEncoding { #[allow(dead_code)] pub(crate) fn to_header_value(self) -> Option { - let accept = match (self.gzip(), self.deflate(), self.br()) { - (true, true, true) => "gzip,deflate,br", - (true, true, false) => "gzip,deflate", - (true, false, true) => "gzip,br", - (true, false, false) => "gzip", - (false, true, true) => "deflate,br", - (false, true, false) => "deflate", - (false, false, true) => "br", - (false, false, false) => return None, + let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { + (true, true, true, false) => "gzip,deflate,br", + (true, true, false, false) => "gzip,deflate", + (true, false, true, false) => "gzip,br", + (true, false, false, false) => "gzip", + (false, true, true, false) => "deflate,br", + (false, true, false, false) => "deflate", + (false, false, true, false) => "br", + (true, true, true, true) => "zstd,gzip,deflate,br", + (true, true, false, true) => "zstd,gzip,deflate", + (true, false, true, true) => "zstd,gzip,br", + (true, false, false, true) => "zstd,gzip", + (false, true, true, true) => "zstd,deflate,br", + (false, true, false, true) => "zstd,deflate", + (false, false, true, true) => "zstd,br", + (false, false, false, true) => "zstd", + (false, false, false, false) => return None, }; Some(HeaderValue::from_static(accept)) } @@ -52,6 +61,11 @@ impl AcceptEncoding { pub(crate) fn set_br(&mut self, enable: bool) { self.br = enable; } + + #[allow(dead_code)] + pub(crate) fn set_zstd(&mut self, enable: bool) { + self.zstd = enable; + } } impl SupportedEncodings for AcceptEncoding { @@ -90,6 +104,18 @@ impl SupportedEncodings for AcceptEncoding { false } } + + #[allow(dead_code)] + fn zstd(&self) -> bool { + #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))] + { + self.zstd + } + #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))] + { + false + } + } } impl Default for AcceptEncoding { @@ -98,6 +124,7 @@ impl Default for AcceptEncoding { gzip: true, deflate: true, br: true, + zstd: true, } } } diff --git a/tower-http/src/content_encoding.rs b/tower-http/src/content_encoding.rs index f609c7d6..c962d0ee 100644 --- a/tower-http/src/content_encoding.rs +++ b/tower-http/src/content_encoding.rs @@ -2,6 +2,7 @@ pub(crate) trait SupportedEncodings: Copy { fn gzip(&self) -> bool; fn deflate(&self) -> bool; fn br(&self) -> bool; + fn zstd(&self) -> bool; } // This enum's variants are ordered from least to most preferred. @@ -15,6 +16,8 @@ pub(crate) enum Encoding { Gzip, #[cfg(any(feature = "fs", feature = "compression-br"))] Brotli, + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + Zstd, } impl Encoding { @@ -27,6 +30,8 @@ impl Encoding { Encoding::Deflate => "deflate", #[cfg(any(feature = "fs", feature = "compression-br"))] Encoding::Brotli => "br", + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + Encoding::Zstd => "zstd", Encoding::Identity => "identity", } } @@ -37,6 +42,7 @@ impl Encoding { Encoding::Gzip => Some(std::ffi::OsStr::new(".gz")), Encoding::Deflate => Some(std::ffi::OsStr::new(".zz")), Encoding::Brotli => Some(std::ffi::OsStr::new(".br")), + Encoding::Zstd => Some(std::ffi::OsStr::new(".zst")), Encoding::Identity => None, } } @@ -50,6 +56,7 @@ impl Encoding { feature = "compression-gzip", feature = "compression-br", feature = "compression-deflate", + feature = "compression-zstd", feature = "fs", ))] fn parse(s: &str, _supported_encoding: impl SupportedEncodings) -> Option { @@ -68,6 +75,11 @@ impl Encoding { return Some(Encoding::Brotli); } + #[cfg(any(feature = "fs", feature = "compression-zstd"))] + if s.eq_ignore_ascii_case("zstd") && _supported_encoding.zstd() { + return Some(Encoding::Zstd); + } + if s.eq_ignore_ascii_case("identity") { return Some(Encoding::Identity); } @@ -78,6 +90,7 @@ impl Encoding { #[cfg(any( feature = "compression-gzip", feature = "compression-br", + feature = "compression-zstd", feature = "compression-deflate", ))] // based on https://github.com/http-rs/accept-encoding @@ -92,6 +105,7 @@ impl Encoding { #[cfg(any( feature = "compression-gzip", feature = "compression-br", + feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] @@ -109,6 +123,7 @@ impl Encoding { #[cfg(any( feature = "compression-gzip", feature = "compression-br", + feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] @@ -118,6 +133,7 @@ pub(crate) struct QValue(u16); #[cfg(any( feature = "compression-gzip", feature = "compression-br", + feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] @@ -189,6 +205,7 @@ impl QValue { #[cfg(any( feature = "compression-gzip", feature = "compression-br", + feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] @@ -225,7 +242,8 @@ pub(crate) fn encodings( test, feature = "compression-gzip", feature = "compression-deflate", - feature = "compression-br" + feature = "compression-br", + feature = "compression-zstd", ))] mod tests { use super::*; @@ -245,6 +263,10 @@ mod tests { fn br(&self) -> bool { true } + + fn zstd(&self) -> bool { + true + } } #[test] diff --git a/tower-http/src/decompression/body.rs b/tower-http/src/decompression/body.rs index 1e6f26a8..7c908bc7 100644 --- a/tower-http/src/decompression/body.rs +++ b/tower-http/src/decompression/body.rs @@ -10,6 +10,8 @@ use async_compression::tokio::bufread::BrotliDecoder; use async_compression::tokio::bufread::GzipDecoder; #[cfg(feature = "decompression-deflate")] use async_compression::tokio::bufread::ZlibDecoder; +#[cfg(feature = "decompression-zstd")] +use async_compression::tokio::bufread::ZstdDecoder; use bytes::{Buf, Bytes}; use futures_util::ready; use http::HeaderMap; @@ -49,6 +51,8 @@ where BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "decompression-br")] BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), BodyInner::Identity { inner } => inner, // FIXME: Remove once possible; see https://github.com/rust-lang/rust/issues/51085 @@ -58,6 +62,8 @@ where BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, } } @@ -70,6 +76,8 @@ where BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "decompression-br")] BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), BodyInner::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] @@ -78,6 +86,8 @@ where BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, } } @@ -108,6 +118,14 @@ where .get_pin_mut() .get_pin_mut() .get_pin_mut(), + #[cfg(feature = "decompression-zstd")] + BodyInnerProj::Zstd { inner } => inner + .project() + .read + .get_pin_mut() + .get_pin_mut() + .get_pin_mut() + .get_pin_mut(), BodyInnerProj::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] @@ -116,6 +134,8 @@ where BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInnerProj::Zstd { inner } => match inner.0 {}, } } @@ -143,6 +163,13 @@ where .into_inner() .into_inner() .into_inner(), + #[cfg(feature = "decompression-zstd")] + BodyInner::Zstd { inner } => inner + .read + .into_inner() + .into_inner() + .into_inner() + .into_inner(), BodyInner::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] @@ -151,6 +178,8 @@ where BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInner::Zstd { inner } => match inner.0 {}, } } } @@ -158,7 +187,8 @@ where #[cfg(any( not(feature = "decompression-gzip"), not(feature = "decompression-deflate"), - not(feature = "decompression-br") + not(feature = "decompression-br"), + not(feature = "decompression-zstd") ))] pub(crate) enum Never {} @@ -177,6 +207,11 @@ type BrotliBody = WrapBody>; #[cfg(not(feature = "decompression-br"))] type BrotliBody = (Never, PhantomData); +#[cfg(feature = "decompression-zstd")] +type ZstdBody = WrapBody>; +#[cfg(not(feature = "decompression-zstd"))] +type ZstdBody = (Never, PhantomData); + pin_project! { #[project = BodyInnerProj] pub(crate) enum BodyInner @@ -195,6 +230,10 @@ pin_project! { #[pin] inner: BrotliBody, }, + Zstd { + #[pin] + inner: ZstdBody, + }, Identity { #[pin] inner: B, @@ -218,6 +257,11 @@ impl BodyInner { Self::Brotli { inner } } + #[cfg(feature = "decompression-zstd")] + pub(crate) fn zstd(inner: WrapBody>) -> Self { + Self::Zstd { inner } + } + pub(crate) fn identity(inner: B) -> Self { Self::Identity { inner } } @@ -242,6 +286,8 @@ where BodyInnerProj::Deflate { inner } => inner.poll_data(cx), #[cfg(feature = "decompression-br")] BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + #[cfg(feature = "decompression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_data(cx), BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { Some(Ok(mut buf)) => { let bytes = buf.copy_to_bytes(buf.remaining()); @@ -257,6 +303,8 @@ where BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInnerProj::Zstd { inner } => match inner.0 {}, } } @@ -271,6 +319,8 @@ where BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), #[cfg(feature = "decompression-br")] BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), + #[cfg(feature = "decompression-zstd")] + BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), #[cfg(not(feature = "decompression-gzip"))] @@ -279,6 +329,8 @@ where BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, + #[cfg(not(feature = "decompression-zstd"))] + BodyInnerProj::Zstd { inner } => match inner.0 {}, } } } @@ -333,3 +385,20 @@ where pinned.get_pin_mut() } } + +#[cfg(feature = "decompression-zstd")] +impl DecorateAsyncRead for ZstdDecoder +where + B: Body, +{ + type Input = AsyncReadBody; + type Output = ZstdDecoder; + + fn apply(input: Self::Input) -> Self::Output { + ZstdDecoder::new(input) + } + + fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { + pinned.get_pin_mut() + } +} diff --git a/tower-http/src/decompression/future.rs b/tower-http/src/decompression/future.rs index a62892d2..43d9044a 100644 --- a/tower-http/src/decompression/future.rs +++ b/tower-http/src/decompression/future.rs @@ -55,6 +55,11 @@ where DecompressionBody::new(BodyInner::brotli(WrapBody::new(body))) } + #[cfg(feature = "decompression-zstd")] + b"zstd" if self.accept.zstd() => { + DecompressionBody::new(BodyInner::zstd(WrapBody::new(body))) + } + _ => { return Poll::Ready(Ok(Response::from_parts( parts, diff --git a/tower-http/src/decompression/layer.rs b/tower-http/src/decompression/layer.rs index 2c1c8911..4a184c16 100644 --- a/tower-http/src/decompression/layer.rs +++ b/tower-http/src/decompression/layer.rs @@ -51,6 +51,13 @@ impl DecompressionLayer { self } + /// Sets whether to request the Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. @@ -74,4 +81,12 @@ impl DecompressionLayer { self.accept.set_br(false); self } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } } diff --git a/tower-http/src/decompression/service.rs b/tower-http/src/decompression/service.rs index a61b5c7e..50e8ead5 100644 --- a/tower-http/src/decompression/service.rs +++ b/tower-http/src/decompression/service.rs @@ -59,6 +59,13 @@ impl Decompression { self } + /// Sets whether to request the Zstd encoding. + #[cfg(feature = "decompression-zstd")] + pub fn zstd(mut self, enable: bool) -> Self { + self.accept.set_zstd(enable); + self + } + /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. @@ -82,6 +89,14 @@ impl Decompression { self.accept.set_br(false); self } + + /// Disables the Zstd encoding. + /// + /// This method is available even if the `zstd` crate feature is disabled. + pub fn no_zstd(mut self) -> Self { + self.accept.set_zstd(false); + self + } } impl Service> for Decompression diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index d1d39073..e6c4186f 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -237,7 +237,8 @@ pub mod propagate_header; #[cfg(any( feature = "compression-br", feature = "compression-deflate", - feature = "compression-gzip" + feature = "compression-gzip", + feature = "compression-zstd", ))] pub mod compression; @@ -250,7 +251,8 @@ pub mod sensitive_headers; #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", - feature = "decompression-gzip" + feature = "decompression-gzip", + feature = "decompression-zstd", ))] pub mod decompression; @@ -258,9 +260,11 @@ pub mod decompression; feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", + feature = "compression-zstd", feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", + feature = "decompression-zstd", feature = "fs" // Used for serving precompressed static files as well ))] mod content_encoding; @@ -269,9 +273,11 @@ mod content_encoding; feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", + feature = "compression-zstd", feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", + feature = "decompression-zstd", ))] mod compression_utils; diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index 7e0f7348..4fac77a2 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -182,6 +182,23 @@ impl ServeDir { self } + /// Informs the service that it should also look for a precompressed zstd + /// version of _any_ file in the directory. + /// + /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, + /// a client with an `Accept-Encoding` header that allows the zstd encoding + /// will receive the file `dir/foo.txt.zst` instead of `dir/foo.txt`. + /// If the precompressed file is not available, or the client doesn't support it, + /// the uncompressed version will be served instead. + /// Both the precompressed version and the uncompressed version are expected + /// to be present in the directory. Different precompressed variants can be combined. + pub fn precompressed_zstd(mut self) -> Self { + self.precompressed_variants + .get_or_insert(Default::default()) + .zstd = true; + self + } + /// Set the fallback service. /// /// This service will be called if there is no file at the path of the request. @@ -534,6 +551,7 @@ struct PrecompressedVariants { gzip: bool, deflate: bool, br: bool, + zstd: bool, } impl SupportedEncodings for PrecompressedVariants { @@ -548,4 +566,8 @@ impl SupportedEncodings for PrecompressedVariants { fn br(&self) -> bool { self.br } + + fn zstd(&self) -> bool { + self.zstd + } }