Skip to content

Commit

Permalink
refactor(timeout): unify and enhance timeout middleware (#3275)
Browse files Browse the repository at this point in the history
* feat(timeout): unify and enhance timeout middleware

- Combine classic context-based timeout with a Goroutine + channel approach
- Support custom error list without additional parameters
- Return fiber.ErrRequestTimeout for timeouts or listed errors

* feat(timeout): unify and enhance timeout middleware

- Combine classic context-based timeout with a Goroutine + channel approach
- Support custom error list without additional parameters
- Return fiber.ErrRequestTimeout for timeouts or listed errors

* refactor(timeout): remove goroutine-based logic and improve documentation

- Switch to a synchronous approach to avoid data races with fasthttp context
- Enhance error handling for deadline and custom errors
- Update comments for clarity and maintainability

* refactor(timeout): add more test cases and handle zero duration case

* refactor(timeout): add more test cases and handle zero duration case

* refactor(timeout): add more test cases and handle zero duration case

---------

Co-authored-by: Juan Calderon-Perez <[email protected]>
  • Loading branch information
ReneWerner87 and gaby authored Jan 8, 2025
1 parent 86d72bb commit bc37f20
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 76 deletions.
57 changes: 43 additions & 14 deletions middleware/timeout/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,52 @@ import (
"github.com/gofiber/fiber/v3"
)

// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
// New enforces a timeout for each incoming request. If the timeout expires or
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
return func(ctx fiber.Ctx) error {
timeoutContext, cancel := context.WithTimeout(ctx.Context(), t)
// If timeout <= 0, skip context.WithTimeout and run the handler as-is.
if timeout <= 0 {
return runHandler(ctx, h, tErrs)
}

// Create a context with the specified timeout; any operation exceeding
// this deadline will be canceled automatically.
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
defer cancel()

// Replace the default Fiber context with our timeout-bound context.
ctx.SetContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err

// Run the handler and check for relevant errors.
err := runHandler(ctx, h, tErrs)

// If the context actually timed out, return a timeout error.
if errors.Is(timeoutContext.Err(), context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
return err
}
}

// runHandler executes the handler and returns fiber.ErrRequestTimeout if it
// sees a deadline exceeded error or one of the custom "timeout-like" errors.
func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error {
// Execute the wrapped handler synchronously.
err := h(c)
// If the context has timed out, return a request timeout error.
if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) {
return fiber.ErrRequestTimeout
}
return err
}

// isCustomError checks whether err matches any error in errList using errors.Is.
func isCustomError(err error, errList []error) bool {
for _, e := range errList {
if errors.Is(err, e) {
return true
}
return nil
}
return false
}
166 changes: 104 additions & 62 deletions middleware/timeout/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,119 @@ import (
"github.com/stretchr/testify/require"
)

// go test -run Test_WithContextTimeout
func Test_WithContextTimeout(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
var (
// Custom error that we treat like a timeout when returned by the handler.
errCustomTimeout = errors.New("custom timeout error")

// Some unrelated error that should NOT trigger a request timeout.
errUnrelated = errors.New("unmatched error")
)

// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled.
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
defer timer.Stop() // Clean up the timer

select {
case <-ctx.Done():
return te
case <-timer.C:
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

var ErrFooTimeOut = errors.New("foo context canceled")
// TestTimeout_Success tests a handler that completes within the allotted timeout.
func TestTimeout_Success(t *testing.T) {
t.Parallel()
app := fiber.New()

// Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit.
app.Get("/fast", New(func(c fiber.Ctx) error {
// Simulate some work
if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return c.SendString("OK")
}, 50*time.Millisecond))

req := httptest.NewRequest(fiber.MethodGet, "/fast", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
}

// go test -run Test_WithContextTimeoutWithCustomError
func Test_WithContextTimeoutWithCustomError(t *testing.T) {
// TestTimeout_Exceeded tests a handler that exceeds the provided timeout.
func TestTimeout_Exceeded(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)

// This handler sleeps 200ms, exceeding the 100ms limit.
app.Get("/slow", New(func(c fiber.Ctx) error {
if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
return c.SendString("Should never get here")
}, 100*time.Millisecond))

req := httptest.NewRequest(fiber.MethodGet, "/slow", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
}

func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
func TestTimeout_CustomError(t *testing.T) {
t.Parallel()
app := fiber.New()

// This handler sleeps 50ms and returns errCustomTimeout if canceled.
app.Get("/custom", New(func(c fiber.Ctx) error {
// Sleep might time out, or might return early. If the context is canceled,
// we treat errCustomTimeout as a 'timeout-like' condition.
if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil {
return fmt.Errorf("wrapped: %w", err)
}
return te
case <-timer.C:
}
return nil
return c.SendString("Should never get here")
}, 100*time.Millisecond, errCustomTimeout))

req := httptest.NewRequest(fiber.MethodGet, "/custom", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
}

// TestTimeout_UnmatchedError checks that if the handler returns an error
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
// propagated as a regular 500 (internal server error).
func TestTimeout_UnmatchedError(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Get("/unmatched", New(func(_ fiber.Ctx) error {
return errUnrelated // Not in the custom error list
}, 100*time.Millisecond, errCustomTimeout))

req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
"Expected 500 because the error is not recognized as a timeout error")
}

// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero.
// Usually this means the request can never exceed a 'deadline' – effectively no timeout.
func TestTimeout_ZeroDuration(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Get("/zero", New(func(c fiber.Ctx) error {
// Sleep 50ms, but there's no real 'deadline' since zero-timeout.
time.Sleep(50 * time.Millisecond)
return c.SendString("No timeout used")
}, 0))

req := httptest.NewRequest(fiber.MethodGet, "/zero", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
}

1 comment on commit bc37f20

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.50.

Benchmark suite Current: bc37f20 Previous: e04f815 Ratio
Benchmark_Ctx_Send 6.813 ns/op 0 B/op 0 allocs/op 4.343 ns/op 0 B/op 0 allocs/op 1.57
Benchmark_Ctx_Send - ns/op 6.813 ns/op 4.343 ns/op 1.57
Benchmark_Utils_GetOffer/1_parameter 223.4 ns/op 0 B/op 0 allocs/op 136.5 ns/op 0 B/op 0 allocs/op 1.64
Benchmark_Utils_GetOffer/1_parameter - ns/op 223.4 ns/op 136.5 ns/op 1.64
`Benchmark_RoutePatternMatch//api/:param/fixedEnd_ not_match _/api/abc/def/fixedEnd - allocs/op` 14 allocs/op
Benchmark_Middleware_BasicAuth - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_Middleware_BasicAuth_Upper - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth_Upper - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_CORS_NewHandler - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandler - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflight - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflight - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_Middleware_CSRF_GenerateToken - B/op 519 B/op 341 B/op 1.52
Benchmark_Middleware_CSRF_GenerateToken - allocs/op 10 allocs/op 6 allocs/op 1.67

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.