Skip to content

Commit

Permalink
Implement handleErrorCauseZIO for executing ZIO effects on error (#2513)
Browse files Browse the repository at this point in the history
* Implement handleErrorCauseZIO for executing side effects on error

* fix: scala3 macro error

* Transforms all failures of the handler effectfully except pure interruption
  • Loading branch information
SHSongs authored Nov 15, 2023
1 parent 1c8770b commit b3b4cb0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
13 changes: 13 additions & 0 deletions zio-http/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,19 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
): Handler[R1, Err1, In, Out1] =
self.foldHandler(err => Handler.fromZIO(f(err)), Handler.succeed(_))

/**
* Transforms all failures of the handler effectfully except pure
* interruption.
*/
final def mapErrorCauseZIO[R1 <: R, Err1, Out1 >: Out](
f: Cause[Err] => ZIO[R1, Err1, Out1],
)(implicit trace: Trace): Handler[R1, Err1, In, Out1] =
self.foldCauseHandler(
err =>
if (err.isInterruptedOnly) Handler.failCause(err.asInstanceOf[Cause[Nothing]]) else Handler.fromZIO(f(err)),
Handler.succeed(_),
)

/**
* Returns a new handler where the error channel has been merged into the
* success channel to their common combined type.
Expand Down
26 changes: 26 additions & 0 deletions zio-http/src/main/scala/zio/http/Route.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,32 @@ sealed trait Route[-Env, +Err] { self =>
Handled(rpm.routePattern, handler2, location)
}

final def handleErrorCauseZIO(
f: Cause[Err] => ZIO[Any, Nothing, Response],
)(implicit trace: Trace): Route[Env, Nothing] =
self match {
case Provided(route, env) => Provided(route.handleErrorCauseZIO(f), env)
case Augmented(route, aspect) => Augmented(route.handleErrorCauseZIO(f), aspect)
case Handled(routePattern, handler, location) => Handled(routePattern, handler, location)

case Unhandled(rpm, handler, zippable, location) =>
val handler2: Handler[Env, Response, Request, Response] = {
val paramHandler =
Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) =>
rpm.routePattern.decode(request.method, request.path) match {
case Left(error) => ZIO.dieMessage(error)
case Right(value) =>
val params = rpm.zippable.zip(value, ctx)

handler(zippable.zip(params, request))
}
}
rpm.aspect.applyHandlerContext(paramHandler.mapErrorCauseZIO(f))
}

Handled(rpm.routePattern, handler2, location)
}

/**
* Determines if the route is defined for the specified request.
*/
Expand Down
3 changes: 3 additions & 0 deletions zio-http/src/main/scala/zio/http/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ final class Routes[-Env, +Err] private (val routes: Chunk[zio.http.Route[Env, Er
def handleErrorCause(f: Cause[Err] => Response)(implicit trace: Trace): Routes[Env, Nothing] =
new Routes(routes.map(_.handleErrorCause(f)))

def handleErrorCauseZIO(f: Cause[Err] => ZIO[Any, Nothing, Response])(implicit trace: Trace): Routes[Env, Nothing] =
new Routes(routes.map(_.handleErrorCauseZIO(f)))

/**
* Returns new routes that have each been provided the specified environment,
* thus eliminating their requirement for any specific environment.
Expand Down
18 changes: 16 additions & 2 deletions zio-http/src/test/scala/zio/http/RouteSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package zio.http

import scala.collection.Seq

import zio._
import zio.test._

Expand Down Expand Up @@ -64,5 +62,21 @@ object RouteSpec extends ZIOHttpSpec {
} yield assertTrue(cnt == 2)
},
),
suite("error handle")(
test("handleErrorCauseZIO should execute a ZIO effect") {
val route = Method.GET / "endpoint" -> handler { (req: Request) => ZIO.fail(new Exception("hmm...")) }
for {
p <- zio.Promise.make[Exception, String]

errorHandled = route
.handleErrorCauseZIO(c => p.failCause(c).as(Response.internalServerError))

request = Request.get(URL.decode("/endpoint").toOption.get)
response <- errorHandled.toHttpApp.runZIO(request)
result <- p.await.catchAllCause(c => ZIO.succeed(c.prettyPrint))

} yield assertTrue(extractStatus(response) == Status.InternalServerError, result.contains("hmm..."))
},
),
)
}

0 comments on commit b3b4cb0

Please sign in to comment.