Skip to content

Commit

Permalink
Add zstd compression support (#1532)
Browse files Browse the repository at this point in the history
* Implement zstd compression

* Parametrize compression tests

* add tests for accepting multiple encodings

* add some missing feature cfg for zstd

* make as_str only crate public

* make into_accept_encoding_header_value handle all combinations

* make decompress implementation consistent

* use zstd::stream::read::Encoder

* use default compression level for zstd

* fix rebase

* fix CI issue

---------

Co-authored-by: martinabeleda <[email protected]>
Co-authored-by: Quentin Perez <[email protected]>
Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2023
1 parent 53267a3 commit e8cb48f
Show file tree
Hide file tree
Showing 11 changed files with 560 additions and 125 deletions.
3 changes: 2 additions & 1 deletion tests/compression/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ bytes = "1"
http = "0.2"
http-body = "0.4"
hyper = "0.14.3"
paste = "1.0.12"
pin-project = "1.0"
prost = "0.12"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio-stream = "0.1"
tonic = {path = "../../tonic", features = ["gzip"]}
tonic = {path = "../../tonic", features = ["gzip", "zstd"]}
tower = {version = "0.4", features = []}
tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]}

Expand Down
54 changes: 43 additions & 11 deletions tests/compression/src/bidirectional_stream.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,45 @@
use super::*;
use http_body::Body;
use tonic::codec::CompressionEncoding;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
util::parametrized_tests! {
client_enabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default())
.accept_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Gzip);
.accept_compressed(encoding)
.send_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));
let response_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
#[derive(Clone)]
pub struct AssertRightEncoding {
encoding: CompressionEncoding,
}

#[allow(dead_code)]
impl AssertRightEncoding {
pub fn new(encoding: CompressionEncoding) -> Self {
Self { encoding }
}

pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
let expected = match self.encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", self.encoding),
};
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);

req
}
}

tokio::spawn({
Expand All @@ -24,7 +49,9 @@ async fn client_enabled_server_enabled() {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.map_request(move |req| {
AssertRightEncoding::new(encoding).clone().call(req)
})
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
Expand All @@ -44,8 +71,8 @@ async fn client_enabled_server_enabled() {
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Gzip);
.send_compressed(encoding)
.accept_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -56,7 +83,12 @@ async fn client_enabled_server_enabled() {
.await
.unwrap();

assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);

let mut stream: Streaming<SomeData> = res.into_inner();

Expand Down
108 changes: 81 additions & 27 deletions tests/compression/src/client_stream.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,42 @@
use super::*;
use http_body::Body as _;
use http_body::Body;
use tonic::codec::CompressionEncoding;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
util::parametrized_tests! {
client_enabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
#[derive(Clone)]
pub struct AssertRightEncoding {
encoding: CompressionEncoding,
}

#[allow(dead_code)]
impl AssertRightEncoding {
pub fn new(encoding: CompressionEncoding) -> Self {
Self { encoding }
}

pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
let expected = match self.encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", self.encoding),
};
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);

req
}
}

tokio::spawn({
Expand All @@ -22,7 +45,9 @@ async fn client_enabled_server_enabled() {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.map_request(move |req| {
AssertRightEncoding::new(encoding).clone().call(req)
})
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
Expand All @@ -35,8 +60,8 @@ async fn client_enabled_server_enabled() {
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -48,12 +73,17 @@ async fn client_enabled_server_enabled() {
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_disabled_server_enabled() {
util::parametrized_tests! {
client_disabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

Expand Down Expand Up @@ -93,8 +123,14 @@ async fn client_disabled_server_enabled() {
assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_disabled() {
util::parametrized_tests! {
client_enabled_server_disabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default());
Expand All @@ -107,8 +143,8 @@ async fn client_enabled_server_disabled() {
.unwrap();
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -117,18 +153,31 @@ async fn client_enabled_server_disabled() {
let status = client.compress_input_client_stream(req).await.unwrap_err();

assert_eq!(status.code(), tonic::Code::Unimplemented);
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(
status.message(),
"Content is compressed with `gzip` which isn't supported"
format!(
"Content is compressed with `{}` which isn't supported",
expected
)
);
}

#[tokio::test(flavor = "multi_thread")]
async fn compressing_response_from_client_stream() {
util::parametrized_tests! {
compressing_response_from_client_stream,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn compressing_response_from_client_stream(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);

let response_bytes_counter = Arc::new(AtomicUsize::new(0));

Expand All @@ -153,13 +202,18 @@ async fn compressing_response_from_client_stream() {
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.accept_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);

let req = Request::new(Box::pin(tokio_stream::empty()));

let res = client.compress_output_client_stream(req).await.unwrap();
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
let bytes_sent = response_bytes_counter.load(SeqCst);
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}
Loading

0 comments on commit e8cb48f

Please sign in to comment.