Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(timeout): unify and enhance timeout middleware #3275

Merged
merged 7 commits into from
Jan 8, 2025
59 changes: 47 additions & 12 deletions middleware/timeout/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,58 @@ import (
"github.com/gofiber/fiber/v3"
)

// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
// New enforces a timeout for each incoming request. If the timeout expires or
// any of the specified errors occur, fiber.ErrRequestTimeout is returned.
func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler {
return func(ctx fiber.Ctx) error {
timeoutContext, cancel := context.WithTimeout(ctx.Context(), t)
// If timeout <= 0, skip context.WithTimeout and run the handler as-is.
if timeout <= 0 {
return runHandler(ctx, h, tErrs)
}

// Create a context with the specified timeout; any operation exceeding
// this deadline will be canceled automatically.
timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout)
defer cancel()

// Attach the timeout-bound context to the current Fiber context.
ctx.SetContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {

// Execute the wrapped handler synchronously.
err := h(ctx)

// If the context has timed out, return a request timeout error.
if timeoutContext.Err() != nil && errors.Is(timeoutContext.Err(), context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}

// If the handler returned an error, check whether it's a deadline exceeded
// error or any other listed timeout-triggering error.
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err
}
return nil
return err
}
}

// runHandler executes the handler and returns fiber.ErrRequestTimeout if it
// sees a deadline exceeded error or one of the custom "timeout-like" errors.
func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error {
err := h(c)
if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) {
return fiber.ErrRequestTimeout
}
return err
}

// isCustomError checks whether err matches any error in errList using errors.Is.
func isCustomError(err error, errList []error) bool {
for _, e := range errList {
if errors.Is(err, e) {
return true
}
}
return false
}
163 changes: 100 additions & 63 deletions middleware/timeout/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,114 @@ import (
"github.com/stretchr/testify/require"
)

// go test -run Test_WithContextTimeout
func Test_WithContextTimeout(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
var (
// Custom error that we treat like a timeout when returned by the handler.
errCustomTimeout = errors.New("custom timeout error")

// Some unrelated error that should NOT trigger a request timeout.
errUnrelated = errors.New("unmatched error")
)

// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled.
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
defer timer.Stop() // Clean up the timer

select {
case <-ctx.Done():
return te
case <-timer.C:
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

var ErrFooTimeOut = errors.New("foo context canceled")
// TestTimeout_Success tests a handler that completes within the allotted timeout.
func TestTimeout_Success(t *testing.T) {
app := fiber.New()

// Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit.
app.Get("/fast", New(func(c fiber.Ctx) error {
// Simulate some work
if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return c.SendString("OK")
}, 50*time.Millisecond))

// go test -run Test_WithContextTimeoutWithCustomError
func Test_WithContextTimeoutWithCustomError(t *testing.T) {
t.Parallel()
// fiber instance
req := httptest.NewRequest(fiber.MethodGet, "/fast", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
}

// TestTimeout_Exceeded tests a handler that exceeds the provided timeout.
func TestTimeout_Exceeded(t *testing.T) {
app := fiber.New()
h := New(func(c fiber.Ctx) error {
sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms")
require.NoError(t, err)
if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)

// This handler sleeps 200ms, exceeding the 100ms limit.
app.Get("/slow", New(func(c fiber.Ctx) error {
if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
return err
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
return c.SendString("Should never get here")
}, 100*time.Millisecond))

req := httptest.NewRequest(fiber.MethodGet, "/slow", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
}

func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
func TestTimeout_CustomError(t *testing.T) {
app := fiber.New()

// This handler sleeps 50ms and returns errCustomTimeout if canceled.
app.Get("/custom", New(func(c fiber.Ctx) error {
// Sleep might time out, or might return early. If the context is canceled,
// we treat errCustomTimeout as a 'timeout-like' condition.
if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil {
return fmt.Errorf("wrapped: %w", err)
}
return te
case <-timer.C:
}
return nil
return c.SendString("Should never get here")
}, 100*time.Millisecond, errCustomTimeout))

req := httptest.NewRequest(fiber.MethodGet, "/custom", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
}

// TestTimeout_UnmatchedError checks that if the handler returns an error
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
// propagated as a regular 500 (internal server error).
func TestTimeout_UnmatchedError(t *testing.T) {
app := fiber.New()

app.Get("/unmatched", New(func(_ fiber.Ctx) error {
return errUnrelated // Not in the custom error list
}, 100*time.Millisecond, errCustomTimeout))

req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
"Expected 500 because the error is not recognized as a timeout error")
}

// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero.
// Usually this means the request can never exceed a 'deadline' – effectively no timeout.
func TestTimeout_ZeroDuration(t *testing.T) {
app := fiber.New()

app.Get("/zero", New(func(c fiber.Ctx) error {
// Sleep 50ms, but there's no real 'deadline' since zero-timeout.
time.Sleep(50 * time.Millisecond)
return c.SendString("No timeout used")
}, 0))

req := httptest.NewRequest(fiber.MethodGet, "/zero", nil)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
}
Loading