From 2d267b56ef2e50981c90a056292e519810047084 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Sat, 13 Jan 2024 20:20:01 +0100 Subject: [PATCH] Reproducer and fix for multipart/form-data bug (#2468) * Reproducer for multipart/form-data bug * Fix * Fix * Fix 2.12 --- .../main/scala/zio/http/StreamingForm.scala | 18 +++-- .../http/codec/internal/EncoderDecoder.scala | 21 +++--- .../zio/http/endpoint/RoundtripSpec.scala | 70 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/zio-http/src/main/scala/zio/http/StreamingForm.scala b/zio-http/src/main/scala/zio/http/StreamingForm.scala index 859f4085c..cc4bbf14b 100644 --- a/zio-http/src/main/scala/zio/http/StreamingForm.scala +++ b/zio-http/src/main/scala/zio/http/StreamingForm.scala @@ -49,8 +49,9 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: for { runtime <- ZIO.runtime[Any] buffer <- ZIO.succeed(new Buffer(bufferSize)) + abort <- Promise.make[Nothing, Unit] fieldQueue <- Queue.bounded[Take[Throwable, FormField]](4) - reader = + reader = source .mapAccumImmediate(initialState) { (state, byte) => state.formState match { @@ -60,7 +61,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: case Some(queue) => val takes = buffer.addByte(crlfBoundary, byte) if (takes.nonEmpty) { - runtime.unsafe.run(queue.offerAll(takes)).getOrThrowFiberFailure() + runtime.unsafe.run(queue.offerAll(takes).raceFirst(abort.await)).getOrThrowFiberFailure() } case None => } @@ -142,11 +143,16 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: } _ <- reader.runDrain.catchAllCause { cause => fieldQueue.offer(Take.failCause(cause)) + }.ensuring( + fieldQueue.offer(Take.end), + ).forkScoped + .interruptible + _ <- Scope.addFinalizerExit { exit => + // If the fieldStream fails, we need to make sure the reader stream can be interrupted, as it may be blocked + // in the unsafe.run(queue.offer) call (interruption does not propagate into the unsafe.run). This is implemented + // by setting the abort promise which is raced within the unsafe run when offering the element to the queue. + abort.succeed(()).when(exit.isFailure) } - .ensuring( - fieldQueue.offer(Take.end), - ) - .forkScoped fieldStream = ZStream.fromQueue(fieldQueue).flattenTake } yield fieldStream } diff --git a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index bb7afab79..a5b778aaa 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -385,7 +385,7 @@ private[codec] object EncoderDecoder { private def processStreamingForm(form: StreamingForm, inputs: Array[Any])(implicit trace: Trace, ): ZIO[Any, Throwable, Unit] = - Promise.make[HttpCodecError, Unit].flatMap { ready => + Promise.make[Throwable, Unit].flatMap { ready => form.fields.mapZIO { field => indexByName.get(field.name) match { case Some(idx) => @@ -404,20 +404,19 @@ private[codec] object EncoderDecoder { ZIO.unit case _ => formFieldDecoders(idx)(field).map { result => inputs(idx) = result } - }) *> - ready - .succeed(()) - .unless( - inputs.exists(_ == null), - ) // Marking as ready so the handler can start consuming the streaming field before this stream ends + }) + .zipRight( + ready + .succeed(()) + .unless( + inputs.exists(_ == null), + ), // Marking as ready so the handler can start consuming the streaming field before this stream ends + ) case None => ready.fail(HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}")) } }.runDrain - .zipRight( - ready - .succeed(()), - ) + .intoPromise(ready) .forkDaemon .zipRight( ready.await, diff --git a/zio-http/src/test/scala/zio/http/endpoint/RoundtripSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/RoundtripSpec.scala index eff5ce37a..3cd8eed57 100644 --- a/zio-http/src/test/scala/zio/http/endpoint/RoundtripSpec.scala +++ b/zio-http/src/test/scala/zio/http/endpoint/RoundtripSpec.scala @@ -16,12 +16,14 @@ package zio.http.endpoint +import java.time.Instant + import zio._ import zio.test.Assertion._ import zio.test.TestAspect._ import zio.test._ -import zio.stream.ZStream +import zio.stream.{Take, ZStream} import zio.schema.{DeriveSchema, Schema} @@ -30,6 +32,7 @@ import zio.http.Method._ import zio.http._ import zio.http.codec.HttpCodec.{authorization, query} import zio.http.codec.{Doc, HeaderCodec, HttpCodec, QueryCodec} +import zio.http.endpoint.EndpointSpec.ImageMetadata import zio.http.netty.server.NettyDriver object RoundtripSpec extends ZIOHttpSpec { @@ -95,6 +98,20 @@ object RoundtripSpec extends ZIOHttpSpec { result <- outF(out) } yield result + def testEndpointCustomRequestZIO[P, In, Err, Out]( + route: Routes[Any, Nothing], + in: Request, + outF: Response => ZIO[Any, Err, TestResult], + ): zio.ZIO[Server with Client with Scope, Err, TestResult] = + ZIO.scoped[Client with Server] { + for { + port <- Server.install(route.toHttpApp @@ Middleware.requestLogging()) + client <- ZIO.service[Client] + out <- client.request(in.updateURL(_.host("localhost").port(port))).orDie + result <- outF(out) + } yield result + } + def testEndpointError[P, In, Err, Out]( endpoint: Endpoint[P, In, Err, Out, EndpointMiddleware.None.type], route: Routes[Any, Nothing], @@ -461,6 +478,57 @@ object RoundtripSpec extends ZIOHttpSpec { ) } }, + test("multi-part input with stream and invalid json field") { + val api = Endpoint(POST / "test") + .in[String]("name") + .in[ImageMetadata]("metadata") + .inStream[Byte]("file") + .out[String] + + val route = api.implement { + Handler.fromFunctionZIO { case (name, metadata, file) => + file.runCount.map { n => + s"name: $name, metadata: $metadata, count: $n" + } + } + } + + Random + .nextBytes(1024 * 1024) + .flatMap { bytes => + testEndpointCustomRequestZIO( + Routes(route), + Request.post( + "/test", + Body.fromMultipartForm( + Form( + FormField.textField("name", """"xyz"""", MediaType.application.`json`), + FormField.textField( + "metadata", + """{"description": "sample description", "modifiedAt": "2023-10-02T10:30:00.00Z"}""", + MediaType.application.`json`, + ), + FormField.streamingBinaryField( + "file", + ZStream.fromChunk(bytes).rechunk(1000), + MediaType.application.`octet-stream`, + ), + ), + Boundary("bnd1234"), + ), + ), + response => + response.body.asString.map(s => + assertTrue( + s == """"name: xyz, metadata: ImageMetadata(sample description,2023-10-02T10:30:00Z), count: 1048576"""", + ), + ), + ) + } + .map { r => + assert(r.isFailure)(isTrue) // We expect it to fail but complete + } + }, ).provide( Server.live, ZLayer.succeed(Server.Config.default.onAnyOpenPort.enableRequestStreaming),