Skip to content

Commit

Permalink
Add deflate compression support
Browse files Browse the repository at this point in the history
  • Loading branch information
a1ien committed Jan 3, 2025
1 parent 5e9a5bc commit a8ddaf7
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/compression/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pin-project = "1.0"
prost = "0.13"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio-stream = "0.1"
tonic = {path = "../../tonic", features = ["gzip", "zstd"]}
tonic = {path = "../../tonic", features = ["gzip", "deflate", "zstd"]}
tower = "0.5"
tower-http = {version = "0.6", features = ["map-response-body", "map-request-body"]}

Expand Down
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ version = "0.13.0"
[features]
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
deflate = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
prost = ["dep:prost"]
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ impl GrpcConfig {
.headers_mut()
.insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);

#[cfg(any(feature = "gzip", feature = "zstd"))]
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
if let Some(encoding) = self.send_compression_encodings {
request.headers_mut().insert(
crate::codec::compression::ENCODING_HEADER,
Expand Down
51 changes: 41 additions & 10 deletions tonic/src/codec/compression.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::{metadata::MetadataValue, Status};
use bytes::{Buf, BufMut, BytesMut};
#[cfg(feature = "deflate")]
use flate2::read::{DeflateDecoder, DeflateEncoder};
#[cfg(feature = "gzip")]
use flate2::read::{GzDecoder, GzEncoder};
use std::fmt;
Expand All @@ -14,7 +16,7 @@ pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
/// Represents an ordered list of compression encodings that are enabled.
#[derive(Debug, Default, Clone, Copy)]
pub struct EnabledCompressionEncodings {
inner: [Option<CompressionEncoding>; 2],
inner: [Option<CompressionEncoding>; 3],
}

impl EnabledCompressionEncodings {
Expand Down Expand Up @@ -85,6 +87,9 @@ pub enum CompressionEncoding {
#[cfg(feature = "gzip")]
Gzip,
#[allow(missing_docs)]
#[cfg(feature = "deflate")]
Deflate,
#[allow(missing_docs)]
#[cfg(feature = "zstd")]
Zstd,
}
Expand All @@ -93,6 +98,8 @@ impl CompressionEncoding {
pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip,
#[cfg(feature = "deflate")]
CompressionEncoding::Deflate,
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd,
];
Expand All @@ -112,6 +119,8 @@ impl CompressionEncoding {
split_by_comma(header_value_str).find_map(|value| match value {
#[cfg(feature = "gzip")]
"gzip" => Some(CompressionEncoding::Gzip),
#[cfg(feature = "deflate")]
"deflate" => Some(CompressionEncoding::Deflate),
#[cfg(feature = "zstd")]
"zstd" => Some(CompressionEncoding::Zstd),
_ => None,
Expand All @@ -132,6 +141,10 @@ impl CompressionEncoding {
b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
Ok(Some(CompressionEncoding::Gzip))
}
#[cfg(feature = "deflate")]
b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
Ok(Some(CompressionEncoding::Deflate))
}
#[cfg(feature = "zstd")]
b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
Ok(Some(CompressionEncoding::Zstd))
Expand Down Expand Up @@ -170,6 +183,8 @@ impl CompressionEncoding {
match self {
#[cfg(feature = "gzip")]
CompressionEncoding::Gzip => "gzip",
#[cfg(feature = "deflate")]
CompressionEncoding::Deflate => "deflate",
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => "zstd",
}
Expand Down Expand Up @@ -217,6 +232,15 @@ pub(crate) fn compress(
);
std::io::copy(&mut gzip_encoder, &mut out_writer)?;
}
#[cfg(feature = "deflate")]
CompressionEncoding::Deflate => {
let mut deflate_encoder = DeflateEncoder::new(
&decompressed_buf[0..len],
// FIXME: support customizing the compression level
flate2::Compression::new(6),
);
std::io::copy(&mut deflate_encoder, &mut out_writer)?;
}
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => {
let mut zstd_encoder = Encoder::new(
Expand Down Expand Up @@ -247,7 +271,7 @@ pub(crate) fn decompress(
((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
out_buf.reserve(capacity);

#[cfg(any(feature = "gzip", feature = "zstd"))]
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
let mut out_writer = out_buf.writer();

match settings.encoding {
Expand All @@ -256,6 +280,11 @@ pub(crate) fn decompress(
let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
std::io::copy(&mut gzip_decoder, &mut out_writer)?;
}
#[cfg(feature = "deflate")]
CompressionEncoding::Deflate => {
let mut deflate_decoder = DeflateDecoder::new(&compressed_buf[0..len]);
std::io::copy(&mut deflate_decoder, &mut out_writer)?;
}
#[cfg(feature = "zstd")]
CompressionEncoding::Zstd => {
let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
Expand All @@ -282,7 +311,7 @@ pub enum SingleMessageCompressionOverride {

#[cfg(test)]
mod tests {
#[cfg(any(feature = "gzip", feature = "zstd"))]
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
use http::HeaderValue;

use super::*;
Expand All @@ -300,13 +329,13 @@ mod tests {
const GZIP: HeaderValue = HeaderValue::from_static("gzip,identity");

let encodings = EnabledCompressionEncodings {
inner: [Some(CompressionEncoding::Gzip), None],
inner: [Some(CompressionEncoding::Gzip), None, None],
};

assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP);

let encodings = EnabledCompressionEncodings {
inner: [None, Some(CompressionEncoding::Gzip)],
inner: [None, None, Some(CompressionEncoding::Gzip)],
};

assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP);
Expand All @@ -318,43 +347,45 @@ mod tests {
const ZSTD: HeaderValue = HeaderValue::from_static("zstd,identity");

let encodings = EnabledCompressionEncodings {
inner: [Some(CompressionEncoding::Zstd), None],
inner: [Some(CompressionEncoding::Zstd), None, None],
};

assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD);

let encodings = EnabledCompressionEncodings {
inner: [None, Some(CompressionEncoding::Zstd)],
inner: [None, None, Some(CompressionEncoding::Zstd)],
};

assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD);
}

#[test]
#[cfg(all(feature = "gzip", feature = "zstd"))]
#[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
fn convert_gzip_and_zstd_into_header_value() {
let encodings = EnabledCompressionEncodings {
inner: [
Some(CompressionEncoding::Gzip),
Some(CompressionEncoding::Deflate),
Some(CompressionEncoding::Zstd),
],
};

assert_eq!(
encodings.into_accept_encoding_header_value().unwrap(),
HeaderValue::from_static("gzip,zstd,identity"),
HeaderValue::from_static("gzip,deflate,zstd,identity"),
);

let encodings = EnabledCompressionEncodings {
inner: [
Some(CompressionEncoding::Zstd),
Some(CompressionEncoding::Deflate),
Some(CompressionEncoding::Gzip),
],
};

assert_eq!(
encodings.into_accept_encoding_header_value().unwrap(),
HeaderValue::from_static("zstd,gzip,identity"),
HeaderValue::from_static("zstd,deflate,gzip,identity"),
);
}
}
2 changes: 1 addition & 1 deletion tonic/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl<T> Response<T> {
/// **Note**: This only has effect on responses to unary requests and responses to client to
/// server streams. Response streams (server to client stream and bidirectional streams) will
/// still be compressed according to the configuration of the server.
#[cfg(feature = "gzip")]
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
pub fn disable_compression(&mut self) {
self.extensions_mut()
.insert(crate::codec::compression::SingleMessageCompressionOverride::Disable);
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/server/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ where
.headers
.insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);

#[cfg(any(feature = "gzip", feature = "zstd"))]
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
if let Some(encoding) = accept_encoding {
// Set the content encoding
parts.headers.insert(
Expand Down

0 comments on commit a8ddaf7

Please sign in to comment.