Skip to content

Commit

Permalink
Reproducer and fix for multipart/form-data bug (#2468)
Browse files Browse the repository at this point in the history
* Reproducer for multipart/form-data bug

* Fix

* Fix

* Fix 2.12
  • Loading branch information
vigoo authored Jan 13, 2024
1 parent 97d1140 commit 2d267b5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 18 deletions.
18 changes: 12 additions & 6 deletions zio-http/src/main/scala/zio/http/StreamingForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
for {
runtime <- ZIO.runtime[Any]
buffer <- ZIO.succeed(new Buffer(bufferSize))
abort <- Promise.make[Nothing, Unit]
fieldQueue <- Queue.bounded[Take[Throwable, FormField]](4)
reader =
reader =
source
.mapAccumImmediate(initialState) { (state, byte) =>
state.formState match {
Expand All @@ -60,7 +61,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
case Some(queue) =>
val takes = buffer.addByte(crlfBoundary, byte)
if (takes.nonEmpty) {
runtime.unsafe.run(queue.offerAll(takes)).getOrThrowFiberFailure()
runtime.unsafe.run(queue.offerAll(takes).raceFirst(abort.await)).getOrThrowFiberFailure()
}
case None =>
}
Expand Down Expand Up @@ -142,11 +143,16 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
}
_ <- reader.runDrain.catchAllCause { cause =>
fieldQueue.offer(Take.failCause(cause))
}.ensuring(
fieldQueue.offer(Take.end),
).forkScoped
.interruptible
_ <- Scope.addFinalizerExit { exit =>
// If the fieldStream fails, we need to make sure the reader stream can be interrupted, as it may be blocked
// in the unsafe.run(queue.offer) call (interruption does not propagate into the unsafe.run). This is implemented
// by setting the abort promise which is raced within the unsafe run when offering the element to the queue.
abort.succeed(()).when(exit.isFailure)
}
.ensuring(
fieldQueue.offer(Take.end),
)
.forkScoped
fieldStream = ZStream.fromQueue(fieldQueue).flattenTake
} yield fieldStream
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ private[codec] object EncoderDecoder {
private def processStreamingForm(form: StreamingForm, inputs: Array[Any])(implicit
trace: Trace,
): ZIO[Any, Throwable, Unit] =
Promise.make[HttpCodecError, Unit].flatMap { ready =>
Promise.make[Throwable, Unit].flatMap { ready =>
form.fields.mapZIO { field =>
indexByName.get(field.name) match {
case Some(idx) =>
Expand All @@ -404,20 +404,19 @@ private[codec] object EncoderDecoder {
ZIO.unit
case _ =>
formFieldDecoders(idx)(field).map { result => inputs(idx) = result }
}) *>
ready
.succeed(())
.unless(
inputs.exists(_ == null),
) // Marking as ready so the handler can start consuming the streaming field before this stream ends
})
.zipRight(
ready
.succeed(())
.unless(
inputs.exists(_ == null),
), // Marking as ready so the handler can start consuming the streaming field before this stream ends
)
case None =>
ready.fail(HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}"))
}
}.runDrain
.zipRight(
ready
.succeed(()),
)
.intoPromise(ready)
.forkDaemon
.zipRight(
ready.await,
Expand Down
70 changes: 69 additions & 1 deletion zio-http/src/test/scala/zio/http/endpoint/RoundtripSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

package zio.http.endpoint

import java.time.Instant

import zio._
import zio.test.Assertion._
import zio.test.TestAspect._
import zio.test._

import zio.stream.ZStream
import zio.stream.{Take, ZStream}

import zio.schema.{DeriveSchema, Schema}

Expand All @@ -30,6 +32,7 @@ import zio.http.Method._
import zio.http._
import zio.http.codec.HttpCodec.{authorization, query}
import zio.http.codec.{Doc, HeaderCodec, HttpCodec, QueryCodec}
import zio.http.endpoint.EndpointSpec.ImageMetadata
import zio.http.netty.server.NettyDriver

object RoundtripSpec extends ZIOHttpSpec {
Expand Down Expand Up @@ -95,6 +98,20 @@ object RoundtripSpec extends ZIOHttpSpec {
result <- outF(out)
} yield result

def testEndpointCustomRequestZIO[P, In, Err, Out](
route: Routes[Any, Nothing],
in: Request,
outF: Response => ZIO[Any, Err, TestResult],
): zio.ZIO[Server with Client with Scope, Err, TestResult] =
ZIO.scoped[Client with Server] {
for {
port <- Server.install(route.toHttpApp @@ Middleware.requestLogging())
client <- ZIO.service[Client]
out <- client.request(in.updateURL(_.host("localhost").port(port))).orDie
result <- outF(out)
} yield result
}

def testEndpointError[P, In, Err, Out](
endpoint: Endpoint[P, In, Err, Out, EndpointMiddleware.None.type],
route: Routes[Any, Nothing],
Expand Down Expand Up @@ -461,6 +478,57 @@ object RoundtripSpec extends ZIOHttpSpec {
)
}
},
test("multi-part input with stream and invalid json field") {
val api = Endpoint(POST / "test")
.in[String]("name")
.in[ImageMetadata]("metadata")
.inStream[Byte]("file")
.out[String]

val route = api.implement {
Handler.fromFunctionZIO { case (name, metadata, file) =>
file.runCount.map { n =>
s"name: $name, metadata: $metadata, count: $n"
}
}
}

Random
.nextBytes(1024 * 1024)
.flatMap { bytes =>
testEndpointCustomRequestZIO(
Routes(route),
Request.post(
"/test",
Body.fromMultipartForm(
Form(
FormField.textField("name", """"xyz"""", MediaType.application.`json`),
FormField.textField(
"metadata",
"""{"description": "sample description", "modifiedAt": "2023-10-02T10:30:00.00Z"}""",
MediaType.application.`json`,
),
FormField.streamingBinaryField(
"file",
ZStream.fromChunk(bytes).rechunk(1000),
MediaType.application.`octet-stream`,
),
),
Boundary("bnd1234"),
),
),
response =>
response.body.asString.map(s =>
assertTrue(
s == """"name: xyz, metadata: ImageMetadata(sample description,2023-10-02T10:30:00Z), count: 1048576"""",
),
),
)
}
.map { r =>
assert(r.isFailure)(isTrue) // We expect it to fail but complete
}
},
).provide(
Server.live,
ZLayer.succeed(Server.Config.default.onAnyOpenPort.enableRequestStreaming),
Expand Down

0 comments on commit 2d267b5

Please sign in to comment.