Skip to content

Commit e80c353

Browse files
committed
fix redirect middleware panicing on short host name (fix #1811)
1 parent dec96f0 commit e80c353

File tree

2 files changed

+249
-59
lines changed

2 files changed

+249
-59
lines changed

middleware/redirect.go

+22-23
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"net/http"
5+
"strings"
56

67
"github.com/labstack/echo/v4"
78
)
@@ -40,11 +41,11 @@ func HTTPSRedirect() echo.MiddlewareFunc {
4041
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
4142
// See `HTTPSRedirect()`.
4243
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
43-
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
44-
if ok = scheme != "https"; ok {
45-
url = "https://" + host + uri
44+
return redirect(config, func(scheme, host, uri string) (bool, string) {
45+
if scheme != "https" {
46+
return true, "https://" + host + uri
4647
}
47-
return
48+
return false, ""
4849
})
4950
}
5051

@@ -59,11 +60,11 @@ func HTTPSWWWRedirect() echo.MiddlewareFunc {
5960
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
6061
// See `HTTPSWWWRedirect()`.
6162
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
62-
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
63-
if ok = scheme != "https" && host[:4] != www; ok {
64-
url = "https://www." + host + uri
63+
return redirect(config, func(scheme, host, uri string) (bool, string) {
64+
if scheme != "https" && !strings.HasPrefix(host, www) {
65+
return true, "https://www." + host + uri
6566
}
66-
return
67+
return false, ""
6768
})
6869
}
6970

@@ -79,13 +80,11 @@ func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
7980
// See `HTTPSNonWWWRedirect()`.
8081
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
8182
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
82-
if ok = scheme != "https"; ok {
83-
if host[:4] == www {
84-
host = host[4:]
85-
}
86-
url = "https://" + host + uri
83+
if scheme != "https" {
84+
host = strings.TrimPrefix(host, www)
85+
return true, "https://" + host + uri
8786
}
88-
return
87+
return false, ""
8988
})
9089
}
9190

@@ -100,11 +99,11 @@ func WWWRedirect() echo.MiddlewareFunc {
10099
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
101100
// See `WWWRedirect()`.
102101
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
103-
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
104-
if ok = host[:4] != www; ok {
105-
url = scheme + "://www." + host + uri
102+
return redirect(config, func(scheme, host, uri string) (bool, string) {
103+
if !strings.HasPrefix(host, www) {
104+
return true, scheme + "://www." + host + uri
106105
}
107-
return
106+
return false, ""
108107
})
109108
}
110109

@@ -119,17 +118,17 @@ func NonWWWRedirect() echo.MiddlewareFunc {
119118
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
120119
// See `NonWWWRedirect()`.
121120
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
122-
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
123-
if ok = host[:4] == www; ok {
124-
url = scheme + "://" + host[4:] + uri
121+
return redirect(config, func(scheme, host, uri string) (bool, string) {
122+
if strings.HasPrefix(host, www) {
123+
return true, scheme + "://" + host[4:] + uri
125124
}
126-
return
125+
return false, ""
127126
})
128127
}
129128

130129
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
131130
if config.Skipper == nil {
132-
config.Skipper = DefaultTrailingSlashConfig.Skipper
131+
config.Skipper = DefaultRedirectConfig.Skipper
133132
}
134133
if config.Code == 0 {
135134
config.Code = DefaultRedirectConfig.Code

middleware/redirect_test.go

+227-36
Original file line numberDiff line numberDiff line change
@@ -12,62 +12,253 @@ import (
1212
type middlewareGenerator func() echo.MiddlewareFunc
1313

1414
func TestRedirectHTTPSRedirect(t *testing.T) {
15-
res := redirectTest(HTTPSRedirect, "labstack.com", nil)
16-
17-
assert.Equal(t, http.StatusMovedPermanently, res.Code)
18-
assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
19-
}
15+
var testCases = []struct {
16+
whenHost string
17+
whenHeader http.Header
18+
expectLocation string
19+
expectStatusCode int
20+
}{
21+
{
22+
whenHost: "labstack.com",
23+
expectLocation: "https://labstack.com/",
24+
expectStatusCode: http.StatusMovedPermanently,
25+
},
26+
{
27+
whenHost: "labstack.com",
28+
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
29+
expectLocation: "",
30+
expectStatusCode: http.StatusOK,
31+
},
32+
}
2033

21-
func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) {
22-
header := http.Header{}
23-
header.Set(echo.HeaderXForwardedProto, "https")
24-
res := redirectTest(HTTPSRedirect, "labstack.com", header)
34+
for _, tc := range testCases {
35+
t.Run(tc.whenHost, func(t *testing.T) {
36+
res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader)
2537

26-
assert.Equal(t, http.StatusOK, res.Code)
38+
assert.Equal(t, tc.expectStatusCode, res.Code)
39+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
40+
})
41+
}
2742
}
2843

2944
func TestRedirectHTTPSWWWRedirect(t *testing.T) {
30-
res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil)
31-
32-
assert.Equal(t, http.StatusMovedPermanently, res.Code)
33-
assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
34-
}
45+
var testCases = []struct {
46+
whenHost string
47+
whenHeader http.Header
48+
expectLocation string
49+
expectStatusCode int
50+
}{
51+
{
52+
whenHost: "labstack.com",
53+
expectLocation: "https://www.labstack.com/",
54+
expectStatusCode: http.StatusMovedPermanently,
55+
},
56+
{
57+
whenHost: "www.labstack.com",
58+
expectLocation: "",
59+
expectStatusCode: http.StatusOK,
60+
},
61+
{
62+
whenHost: "a.com",
63+
expectLocation: "https://www.a.com/",
64+
expectStatusCode: http.StatusMovedPermanently,
65+
},
66+
{
67+
whenHost: "ip",
68+
expectLocation: "https://www.ip/",
69+
expectStatusCode: http.StatusMovedPermanently,
70+
},
71+
{
72+
whenHost: "labstack.com",
73+
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
74+
expectLocation: "",
75+
expectStatusCode: http.StatusOK,
76+
},
77+
}
3578

36-
func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
37-
header := http.Header{}
38-
header.Set(echo.HeaderXForwardedProto, "https")
39-
res := redirectTest(HTTPSWWWRedirect, "labstack.com", header)
79+
for _, tc := range testCases {
80+
t.Run(tc.whenHost, func(t *testing.T) {
81+
res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader)
4082

41-
assert.Equal(t, http.StatusOK, res.Code)
83+
assert.Equal(t, tc.expectStatusCode, res.Code)
84+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
85+
})
86+
}
4287
}
4388

4489
func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
45-
res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil)
46-
47-
assert.Equal(t, http.StatusMovedPermanently, res.Code)
48-
assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
49-
}
90+
var testCases = []struct {
91+
whenHost string
92+
whenHeader http.Header
93+
expectLocation string
94+
expectStatusCode int
95+
}{
96+
{
97+
whenHost: "www.labstack.com",
98+
expectLocation: "https://labstack.com/",
99+
expectStatusCode: http.StatusMovedPermanently,
100+
},
101+
{
102+
whenHost: "a.com",
103+
expectLocation: "https://a.com/",
104+
expectStatusCode: http.StatusMovedPermanently,
105+
},
106+
{
107+
whenHost: "ip",
108+
expectLocation: "https://ip/",
109+
expectStatusCode: http.StatusMovedPermanently,
110+
},
111+
{
112+
whenHost: "www.labstack.com",
113+
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
114+
expectLocation: "",
115+
expectStatusCode: http.StatusOK,
116+
},
117+
}
50118

51-
func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
52-
header := http.Header{}
53-
header.Set(echo.HeaderXForwardedProto, "https")
54-
res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header)
119+
for _, tc := range testCases {
120+
t.Run(tc.whenHost, func(t *testing.T) {
121+
res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader)
55122

56-
assert.Equal(t, http.StatusOK, res.Code)
123+
assert.Equal(t, tc.expectStatusCode, res.Code)
124+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
125+
})
126+
}
57127
}
58128

59129
func TestRedirectWWWRedirect(t *testing.T) {
60-
res := redirectTest(WWWRedirect, "labstack.com", nil)
130+
var testCases = []struct {
131+
whenHost string
132+
whenHeader http.Header
133+
expectLocation string
134+
expectStatusCode int
135+
}{
136+
{
137+
whenHost: "labstack.com",
138+
expectLocation: "http://www.labstack.com/",
139+
expectStatusCode: http.StatusMovedPermanently,
140+
},
141+
{
142+
whenHost: "a.com",
143+
expectLocation: "http://www.a.com/",
144+
expectStatusCode: http.StatusMovedPermanently,
145+
},
146+
{
147+
whenHost: "ip",
148+
expectLocation: "http://www.ip/",
149+
expectStatusCode: http.StatusMovedPermanently,
150+
},
151+
{
152+
whenHost: "a.com",
153+
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
154+
expectLocation: "https://www.a.com/",
155+
expectStatusCode: http.StatusMovedPermanently,
156+
},
157+
{
158+
whenHost: "www.ip",
159+
expectLocation: "",
160+
expectStatusCode: http.StatusOK,
161+
},
162+
}
163+
164+
for _, tc := range testCases {
165+
t.Run(tc.whenHost, func(t *testing.T) {
166+
res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader)
61167

62-
assert.Equal(t, http.StatusMovedPermanently, res.Code)
63-
assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
168+
assert.Equal(t, tc.expectStatusCode, res.Code)
169+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
170+
})
171+
}
64172
}
65173

66174
func TestRedirectNonWWWRedirect(t *testing.T) {
67-
res := redirectTest(NonWWWRedirect, "www.labstack.com", nil)
175+
var testCases = []struct {
176+
whenHost string
177+
whenHeader http.Header
178+
expectLocation string
179+
expectStatusCode int
180+
}{
181+
{
182+
whenHost: "www.labstack.com",
183+
expectLocation: "http://labstack.com/",
184+
expectStatusCode: http.StatusMovedPermanently,
185+
},
186+
{
187+
whenHost: "www.a.com",
188+
expectLocation: "http://a.com/",
189+
expectStatusCode: http.StatusMovedPermanently,
190+
},
191+
{
192+
whenHost: "www.a.com",
193+
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
194+
expectLocation: "https://a.com/",
195+
expectStatusCode: http.StatusMovedPermanently,
196+
},
197+
{
198+
whenHost: "ip",
199+
expectLocation: "",
200+
expectStatusCode: http.StatusOK,
201+
},
202+
}
203+
204+
for _, tc := range testCases {
205+
t.Run(tc.whenHost, func(t *testing.T) {
206+
res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader)
207+
208+
assert.Equal(t, tc.expectStatusCode, res.Code)
209+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
210+
})
211+
}
212+
}
213+
214+
func TestNonWWWRedirectWithConfig(t *testing.T) {
215+
var testCases = []struct {
216+
name string
217+
givenCode int
218+
givenSkipFunc func(c echo.Context) bool
219+
whenHost string
220+
whenHeader http.Header
221+
expectLocation string
222+
expectStatusCode int
223+
}{
224+
{
225+
name: "usual redirect",
226+
whenHost: "www.labstack.com",
227+
expectLocation: "http://labstack.com/",
228+
expectStatusCode: http.StatusMovedPermanently,
229+
},
230+
{
231+
name: "redirect is skipped",
232+
givenSkipFunc: func(c echo.Context) bool {
233+
return true // skip always
234+
},
235+
whenHost: "www.labstack.com",
236+
expectLocation: "",
237+
expectStatusCode: http.StatusOK,
238+
},
239+
{
240+
name: "redirect with custom status code",
241+
givenCode: http.StatusSeeOther,
242+
whenHost: "www.labstack.com",
243+
expectLocation: "http://labstack.com/",
244+
expectStatusCode: http.StatusSeeOther,
245+
},
246+
}
247+
248+
for _, tc := range testCases {
249+
t.Run(tc.whenHost, func(t *testing.T) {
250+
middleware := func() echo.MiddlewareFunc {
251+
return NonWWWRedirectWithConfig(RedirectConfig{
252+
Skipper: tc.givenSkipFunc,
253+
Code: tc.givenCode,
254+
})
255+
}
256+
res := redirectTest(middleware, tc.whenHost, tc.whenHeader)
68257

69-
assert.Equal(t, http.StatusMovedPermanently, res.Code)
70-
assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation))
258+
assert.Equal(t, tc.expectStatusCode, res.Code)
259+
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
260+
})
261+
}
71262
}
72263

73264
func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {

0 commit comments

Comments
 (0)