Skip to content

Commit

Permalink
Avoid unnecessary parsing of content-type header in `ServerInboundH…
Browse files Browse the repository at this point in the history
…andler` (#2644)

* Parse content-type header only if necessary

* Empty

* PR comments
  • Loading branch information
kyri-petrou authored Jan 24, 2024
1 parent 5c9b8c7 commit 351f6ed
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ class ProbeContentTypeBenchmark {
()
}

@Benchmark
def benchmarkParseMediaTypeSimple(): Unit = {
MediaType.forContentType("application/json")
()
}

@Benchmark
def benchmarkParseMediaTypeNotLowerCase(): Unit = {
MediaType.forContentType("Application/json")
()
}

@Benchmark
def benchmarkParseMediaTypeWithParams(): Unit = {
MediaType.forContentType("application/json; charset=utf-8")
()
}

@Benchmark
def benchmarkParseContentType(): Unit = {
ContentType.parse("application/json; charset=utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class ServerInboundHandlerBenchmark {
private val testUrl = s"$baseUrl/$testEndPoint"
private val testRequest = basicRequest.get(uri"$testUrl")

private val testContentTypeRequest = testRequest.contentType("application/json; charset=utf8")

private val shutdownResponse = Response.text("shutting down")
private val shutdownEndpoint = "shutdown"
private val shutdownUrl = s"http://localhost:8080/$shutdownEndpoint"
Expand Down Expand Up @@ -107,4 +109,11 @@ class ServerInboundHandlerBenchmark {
if (!statusCode.isSuccess)
throw new RuntimeException(s"Received unexpected status code ${statusCode.code}")
}

@Benchmark
def benchmarkSimpleContentType(): Unit = {
val statusCode = testContentTypeRequest.send(backend).code
if (!statusCode.isSuccess)
throw new RuntimeException(s"Received unexpected status code ${statusCode.code}")
}
}
35 changes: 23 additions & 12 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,30 @@ object NettyBody extends BodyEncoding {
private[zio] def fromAsync(
unsafeAsync: UnsafeAsync => Unit,
knownContentLength: Option[Long],
contentTypeHeader: Option[Header.ContentType] = None,
): Body = AsyncBody(
unsafeAsync,
knownContentLength,
contentTypeHeader.map(_.mediaType),
contentTypeHeader.flatMap(_.boundary),
)
contentTypeHeader: Option[String] = None,
): Body = {
val (mediaType, boundary) = mediaTypeAndBoundary(contentTypeHeader)
AsyncBody(
unsafeAsync,
knownContentLength,
mediaType,
boundary,
)
}

/**
* Helper to create Body from ByteBuf
*/
def fromByteBuf(byteBuf: ByteBuf, contentTypeHeader: Option[Header.ContentType] = None): Body =
ByteBufBody(byteBuf, contentTypeHeader.map(_.mediaType), contentTypeHeader.flatMap(_.boundary))
private[zio] def fromByteBuf(byteBuf: ByteBuf, contentTypeHeader: Option[String]): Body = {
val (mediaType, boundary) = mediaTypeAndBoundary(contentTypeHeader)
ByteBufBody(byteBuf, mediaType, boundary)
}

private def mediaTypeAndBoundary(contentTypeHeader: Option[String]) = {
val mediaType = contentTypeHeader.flatMap(MediaType.forContentType)
val boundary = mediaType.flatMap(_.parameters.get("boundary")).map(Boundary(_))
(mediaType, boundary)
}

override def fromCharSequence(charSequence: CharSequence, charset: Charset): Body =
fromAsciiString(new AsciiString(charSequence, charset))
Expand Down Expand Up @@ -84,7 +95,7 @@ object NettyBody extends BodyEncoding {
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
copy(mediaType = Some(newMediaType), boundary = Some(newBoundary))

override def knownContentLength: Option[Long] = Some(asciiString.length().toLong)
}
Expand Down Expand Up @@ -116,7 +127,7 @@ object NettyBody extends BodyEncoding {
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
copy(mediaType = Some(newMediaType), boundary = Some(newBoundary))

override def knownContentLength: Option[Long] = Some(byteBuf.readableBytes().toLong)
}
Expand Down Expand Up @@ -162,7 +173,7 @@ object NettyBody extends BodyEncoding {
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
copy(mediaType = Some(newMediaType), boundary = Some(newBoundary))
}

private[zio] trait UnsafeAsync {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object NettyResponse {
val status = Conversions.statusFromNetty(jRes.status())
val headers = Conversions.headersFromNetty(jRes.headers())
val copiedBuffer = Unpooled.copiedBuffer(jRes.content())
val data = NettyBody.fromByteBuf(copiedBuffer, headers.header(Header.ContentType))
val data = NettyBody.fromByteBuf(copiedBuffer, headers.headers.get(Header.ContentType.name))

Response(status, headers, data)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,13 @@ private[zio] final case class ServerInboundHandler(
case _ => None
}

val headers = Conversions.headersFromNetty(nettyReq.headers())
val contentType = headers.header(Header.ContentType)
val headers = Conversions.headersFromNetty(nettyReq.headers())
val contentTypeHeader = headers.headers.get(Header.ContentType.name)

nettyReq match {
case nettyReq: FullHttpRequest =>
Request(
body = NettyBody.fromByteBuf(
nettyReq.content(),
contentType,
),
body = NettyBody.fromByteBuf(nettyReq.content(), contentTypeHeader),
headers = headers,
method = Conversions.methodFromNetty(nettyReq.method()),
url = URL.decode(nettyReq.uri()).getOrElse(URL.empty),
Expand All @@ -229,7 +226,7 @@ private[zio] final case class ServerInboundHandler(
case nettyReq: HttpRequest =>
val knownContentLength = headers.get(Header.ContentLength).map(_.length)
val handler = addAsyncBodyHandler(ctx)
val body = NettyBody.fromAsync(async => handler.connect(async), knownContentLength, contentType)
val body = NettyBody.fromAsync(async => handler.connect(async), knownContentLength, contentTypeHeader)

Request(
body = body,
Expand Down
11 changes: 11 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/MediaTypeSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ object MediaTypeSpec extends ZIOHttpSpec {
test("predefined mime type parsing") {
assertTrue(MediaType.forContentType("application/json").contains(application.`json`))
},
test("with boundary") {
// NOTE: Testing with non-lowercase values on purpose as spec requires MIME type and param keys to be case-insensitive,and param values case-sensitive
MediaType.forContentType("Multipart/form-data; Boundary=-A-") match {
case None => assertNever("failed to parse media type")
case Some(mt) =>
assertTrue(
mt.fullType == "multipart/form-data",
mt.parameters.get("boundary").contains("-A-"),
)
}
},
test("custom mime type parsing") {
assertTrue(MediaType.parseCustomMediaType("custom/mime").contains(MediaType("custom", "mime")))
},
Expand Down
36 changes: 18 additions & 18 deletions zio-http/shared/src/main/scala/zio/http/MediaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ object MediaType extends MediaTypes {
val mainTypeMap = allMediaTypes.map(m => m.mainType -> m).toMap

def forContentType(contentType: String): Option[MediaType] = {
val index = contentType.indexOf(";")
if (index == -1)
contentTypeMap.get(contentType).orElse(parseCustomMediaType(contentType))
else {
val index = contentType.indexOf(';')
if (index == -1) {
val contentTypeLC = contentType.toLowerCase
contentTypeMap.get(contentTypeLC).orElse(parseCustomMediaType(contentTypeLC))
} else {
val (contentType1, parameter) = contentType.splitAt(index)
val contentTypeLC = contentType1.toLowerCase
contentTypeMap
.get(contentType1)
.orElse(parseCustomMediaType(contentType1))
.map(_.copy(parameters = parseOptionalParameters(parameter.split(";"))))
.get(contentTypeLC)
.orElse(parseCustomMediaType(contentTypeLC))
.map(_.copy(parameters = parseOptionalParameters(parameter.split(';'))))
}
}

Expand All @@ -67,17 +69,15 @@ object MediaType extends MediaTypes {
}

private def parseOptionalParameters(parameters: Array[String]): Map[String, String] = {
@tailrec
def loop(parameters: Seq[String], parameterMap: Map[String, String]): Map[String, String] = parameters match {
case Seq(parameter, tail @ _*) =>
val parameterParts = parameter.split("=")
val newMap =
if (parameterParts.length == 2) parameterMap + (parameterParts.head -> parameterParts(1))
else parameterMap
loop(tail, newMap)
case _ => parameterMap
val builder = Map.newBuilder[String, String]
val size = parameters.length
var i = 0
while (i < size) {
val parameter = parameters(i)
val parts = parameter.split('=')
if (parts.length == 2) builder += ((parts(0).trim.toLowerCase, parts(1).trim))
i += 1
}

loop(parameters.toIndexedSeq, Map.empty).map { case (key, value) => key.trim -> value.trim }
builder.result()
}
}

0 comments on commit 351f6ed

Please sign in to comment.