Skip to content

Commit 4e2e46b

Browse files
author
Ilija Matoski
committed
Timeout middleware implementation, and added go build tag so it build from go1.13 only
1 parent f718079 commit 4e2e46b

File tree

2 files changed

+258
-0
lines changed

2 files changed

+258
-0
lines changed

middleware/timeout.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// +build go1.13
2+
3+
package middleware
4+
5+
import (
6+
"context"
7+
"github.com/labstack/echo/v4"
8+
"time"
9+
)
10+
11+
type (
12+
// TimeoutConfig defines the config for Timeout middleware.
13+
TimeoutConfig struct {
14+
// Skipper defines a function to skip middleware.
15+
Skipper Skipper
16+
// ErrorHandler defines a function which is executed for a timeout
17+
// It can be used to define a custom timeout error
18+
ErrorHandler TimeoutErrorHandlerWithContext
19+
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
20+
Timeout time.Duration
21+
}
22+
23+
// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
24+
// handle the error as we see fit
25+
TimeoutErrorHandlerWithContext func(error, echo.Context) error
26+
)
27+
28+
var (
29+
// DefaultTimeoutConfig is the default Timeout middleware config.
30+
DefaultTimeoutConfig = TimeoutConfig{
31+
Skipper: DefaultSkipper,
32+
Timeout: 0,
33+
ErrorHandler: nil,
34+
}
35+
)
36+
37+
// Timeout returns a middleware which recovers from panics anywhere in the chain
38+
// and handles the control to the centralized HTTPErrorHandler.
39+
func Timeout() echo.MiddlewareFunc {
40+
return TimeoutWithConfig(DefaultTimeoutConfig)
41+
}
42+
43+
// TimeoutWithConfig returns a Timeout middleware with config.
44+
// See: `Timeout()`.
45+
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
46+
// Defaults
47+
if config.Skipper == nil {
48+
config.Skipper = DefaultTimeoutConfig.Skipper
49+
}
50+
51+
return func(next echo.HandlerFunc) echo.HandlerFunc {
52+
return func(c echo.Context) error {
53+
if config.Skipper(c) || config.Timeout == 0 {
54+
return next(c)
55+
}
56+
57+
ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
58+
defer cancel()
59+
60+
// this does a deep clone of the context, wondering if there is a better way to do this?
61+
c.SetRequest(c.Request().Clone(ctx))
62+
63+
done := make(chan error, 1)
64+
go func() {
65+
// This goroutine will keep running even if this middleware times out and
66+
// will be stopped when ctx.Done() is called down the next(c) call chain
67+
done <- next(c)
68+
}()
69+
70+
select {
71+
case <-ctx.Done():
72+
if config.ErrorHandler != nil {
73+
return config.ErrorHandler(ctx.Err(), c)
74+
}
75+
return ctx.Err()
76+
case err := <-done:
77+
return err
78+
}
79+
}
80+
}
81+
}

middleware/timeout_test.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// +build go1.13
2+
3+
package middleware
4+
5+
import (
6+
"context"
7+
"errors"
8+
"github.com/labstack/echo/v4"
9+
"github.com/stretchr/testify/assert"
10+
"net/http"
11+
"net/http/httptest"
12+
"net/url"
13+
"reflect"
14+
"strings"
15+
"testing"
16+
"time"
17+
)
18+
19+
func TestTimeoutSkipper(t *testing.T) {
20+
t.Parallel()
21+
m := TimeoutWithConfig(TimeoutConfig{
22+
Skipper: func(context echo.Context) bool {
23+
return true
24+
},
25+
})
26+
27+
req := httptest.NewRequest(http.MethodGet, "/", nil)
28+
rec := httptest.NewRecorder()
29+
30+
e := echo.New()
31+
c := e.NewContext(req, rec)
32+
33+
err := m(func(c echo.Context) error {
34+
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
35+
return nil
36+
})(c)
37+
38+
assert.NoError(t, err)
39+
}
40+
41+
func TestTimeoutWithTimeout0(t *testing.T) {
42+
t.Parallel()
43+
m := TimeoutWithConfig(TimeoutConfig{
44+
Timeout: 0,
45+
})
46+
47+
req := httptest.NewRequest(http.MethodGet, "/", nil)
48+
rec := httptest.NewRecorder()
49+
50+
e := echo.New()
51+
c := e.NewContext(req, rec)
52+
53+
err := m(func(c echo.Context) error {
54+
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
55+
return nil
56+
})(c)
57+
58+
assert.NoError(t, err)
59+
}
60+
61+
func TestTimeoutIsCancelable(t *testing.T) {
62+
t.Parallel()
63+
m := TimeoutWithConfig(TimeoutConfig{
64+
Timeout: time.Minute,
65+
})
66+
67+
req := httptest.NewRequest(http.MethodGet, "/", nil)
68+
rec := httptest.NewRecorder()
69+
70+
e := echo.New()
71+
c := e.NewContext(req, rec)
72+
73+
err := m(func(c echo.Context) error {
74+
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
75+
return nil
76+
})(c)
77+
78+
assert.NoError(t, err)
79+
}
80+
81+
func TestTimeoutErrorOutInHandler(t *testing.T) {
82+
t.Parallel()
83+
m := Timeout()
84+
85+
req := httptest.NewRequest(http.MethodGet, "/", nil)
86+
rec := httptest.NewRecorder()
87+
88+
e := echo.New()
89+
c := e.NewContext(req, rec)
90+
91+
err := m(func(c echo.Context) error {
92+
return errors.New("err")
93+
})(c)
94+
95+
assert.Error(t, err)
96+
}
97+
98+
func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
99+
t.Parallel()
100+
m := TimeoutWithConfig(TimeoutConfig{
101+
Timeout: time.Second,
102+
ErrorHandler: func(err error, e echo.Context) error {
103+
assert.EqualError(t, err, context.DeadlineExceeded.Error())
104+
return errors.New("err")
105+
},
106+
})
107+
108+
req := httptest.NewRequest(http.MethodGet, "/", nil)
109+
rec := httptest.NewRecorder()
110+
111+
e := echo.New()
112+
c := e.NewContext(req, rec)
113+
114+
err := m(func(c echo.Context) error {
115+
time.Sleep(time.Minute)
116+
return nil
117+
})(c)
118+
119+
assert.EqualError(t, err, errors.New("err").Error())
120+
}
121+
122+
func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
123+
t.Parallel()
124+
m := TimeoutWithConfig(TimeoutConfig{
125+
Timeout: time.Second,
126+
})
127+
128+
req := httptest.NewRequest(http.MethodGet, "/", nil)
129+
rec := httptest.NewRecorder()
130+
131+
e := echo.New()
132+
c := e.NewContext(req, rec)
133+
134+
err := m(func(c echo.Context) error {
135+
time.Sleep(time.Minute)
136+
return nil
137+
})(c)
138+
139+
assert.EqualError(t, err, context.DeadlineExceeded.Error())
140+
}
141+
142+
func TestTimeoutTestRequestClone(t *testing.T) {
143+
t.Parallel()
144+
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
145+
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
146+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
147+
rec := httptest.NewRecorder()
148+
149+
m := TimeoutWithConfig(TimeoutConfig{
150+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
151+
Timeout: time.Second,
152+
})
153+
154+
e := echo.New()
155+
c := e.NewContext(req, rec)
156+
157+
err := m(func(c echo.Context) error {
158+
// Cookie test
159+
cookie, err := c.Request().Cookie("cookie")
160+
if assert.NoError(t, err) {
161+
assert.EqualValues(t, "cookie", cookie.Name)
162+
assert.EqualValues(t, "value", cookie.Value)
163+
}
164+
165+
// Form values
166+
if assert.NoError(t, c.Request().ParseForm()) {
167+
assert.EqualValues(t, "value", c.Request().FormValue("form"))
168+
}
169+
170+
// Query string
171+
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
172+
return nil
173+
})(c)
174+
175+
assert.NoError(t, err)
176+
177+
}

0 commit comments

Comments
 (0)