diff --git a/middleware/slash.go b/middleware/slash.go index 0492b334b..4188675b0 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu } } } + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) + } + return uri +} diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 2a8e9eeaa..ddb071045 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -9,88 +9,270 @@ import ( "github.com/stretchr/testify/assert" ) +func TestAddTrailingSlashWithConfig(t *testing.T) { + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + expectStatus int + }{ + { + whenURL: "/add-slash", + whenMethod: http.MethodGet, + expectPath: "/add-slash", + expectLocation: []string{`/add-slash/`}, + }, + { + whenURL: "/add-slash?key=value", + whenMethod: http.MethodGet, + expectPath: "/add-slash", + expectLocation: []string{`/add-slash/?key=value`}, + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + // cases for open redirect vulnerability + { + whenURL: "http://localhost:1323/%5Cexample.com", + expectPath: `/\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: `http://localhost:1323/\example.com`, + expectPath: `/\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com`, + expectPath: `/\\\////\\\\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: "http://localhost:1323//example.com", + expectPath: `//example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: "http://localhost:1323/%5C%5C", + expectPath: `/\\`, + expectLocation: []string{`/`}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + mw := AddTrailingSlashWithConfig(TrailingSlashConfig{ + RedirectCode: http.StatusMovedPermanently, + }) + h := mw(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation]) + if tc.expectStatus == 0 { + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + } else { + assert.Equal(t, tc.expectStatus, rec.Code) + } + }) + } +} + func TestAddTrailingSlash(t *testing.T) { - is := assert.New(t) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/add-slash", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := AddTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/add-slash/", req.URL.Path) - is.Equal("/add-slash/", req.RequestURI) - - // Method Connect must not fail: - req = httptest.NewRequest(http.MethodConnect, "", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = AddTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/", req.URL.Path) - is.Equal("/", req.RequestURI) - - // With config - req = httptest.NewRequest(http.MethodGet, "/add-slash?key=value", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = AddTrailingSlashWithConfig(TrailingSlashConfig{ - RedirectCode: http.StatusMovedPermanently, - })(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal(http.StatusMovedPermanently, rec.Code) - is.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + }{ + { + whenURL: "/add-slash", + whenMethod: http.MethodGet, + expectPath: "/add-slash/", + }, + { + whenURL: "/add-slash?key=value", + whenMethod: http.MethodGet, + expectPath: "/add-slash/", + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + h := AddTrailingSlash()(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation]) + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +func TestRemoveTrailingSlashWithConfig(t *testing.T) { + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + expectStatus int + }{ + { + whenURL: "/remove-slash/", + whenMethod: http.MethodGet, + expectPath: "/remove-slash/", + expectLocation: []string{`/remove-slash`}, + }, + { + whenURL: "/remove-slash/?key=value", + whenMethod: http.MethodGet, + expectPath: "/remove-slash/", + expectLocation: []string{`/remove-slash?key=value`}, + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + { + whenURL: "http://localhost", + whenMethod: http.MethodGet, + expectPath: "", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + // cases for open redirect vulnerability + { + whenURL: "http://localhost:1323/%5Cexample.com/", + expectPath: `/\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: `http://localhost:1323/\example.com/`, + expectPath: `/\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com/`, + expectPath: `/\\\////\\\\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: "http://localhost:1323//example.com/", + expectPath: `//example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: "http://localhost:1323/%5C%5C/", + expectPath: `/\\/`, + expectLocation: []string{`/`}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{ + RedirectCode: http.StatusMovedPermanently, + }) + h := mw(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation]) + if tc.expectStatus == 0 { + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + } else { + assert.Equal(t, tc.expectStatus, rec.Code) + } + }) + } } func TestRemoveTrailingSlash(t *testing.T) { - is := assert.New(t) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/remove-slash/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/remove-slash", req.URL.Path) - is.Equal("/remove-slash", req.RequestURI) - - // Method Connect must not fail: - req = httptest.NewRequest(http.MethodConnect, "", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("", req.URL.Path) - is.Equal("", req.RequestURI) - - // With config - req = httptest.NewRequest(http.MethodGet, "/remove-slash/?key=value", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{ - RedirectCode: http.StatusMovedPermanently, - })(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal(http.StatusMovedPermanently, rec.Code) - is.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) - - // With bare URL - req = httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("", req.URL.Path) + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + }{ + { + whenURL: "/remove-slash/", + whenMethod: http.MethodGet, + expectPath: "/remove-slash", + }, + { + whenURL: "/remove-slash/?key=value", + whenMethod: http.MethodGet, + expectPath: "/remove-slash", + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + }, + { + whenURL: "http://localhost", + whenMethod: http.MethodGet, + expectPath: "", + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + h := RemoveTrailingSlash()(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation]) + assert.Equal(t, http.StatusOK, rec.Code) + }) + } }