diff --git a/middleware/compress.go b/middleware/compress.go index e4f9fc514..6ae197453 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } - pool := gzipPool(config) + pool := gzipCompressPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { return http.ErrNotSupported } -func gzipPool(config GzipConfig) sync.Pool { +func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) diff --git a/middleware/decompress.go b/middleware/decompress.go index 99eaf066d..c046359a2 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -3,9 +3,12 @@ package middleware import ( "bytes" "compress/gzip" - "github.com/labstack/echo/v4" "io" "io/ioutil" + "net/http" + "sync" + + "github.com/labstack/echo/v4" ) type ( @@ -13,17 +16,55 @@ type ( DecompressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor } ) //GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + var ( //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } ) +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + //Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DefaultDecompressConfig) @@ -31,25 +72,46 @@ func Decompress() echo.MiddlewareFunc { //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := config.GzipDecompressPool.gzipDecompressPool() return func(c echo.Context) error { if config.Skipper(c) { return next(c) } switch c.Request().Header.Get(echo.HeaderContentEncoding) { case GZIPEncoding: - gr, err := gzip.NewReader(c.Request().Body) - if err != nil { + b := c.Request().Body + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + + if err := gr.Reset(b); err != nil { + pool.Put(gr) if err == io.EOF { //ignore if body is empty return next(c) } return err } - defer gr.Close() var buf bytes.Buffer io.Copy(&buf, gr) + + gr.Close() + pool.Put(gr) + + b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + r := ioutil.NopCloser(&buf) - defer r.Close() c.Request().Body = r } return next(c) diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 772c14f6d..51fa6b0f1 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -3,10 +3,12 @@ package middleware import ( "bytes" "compress/gzip" + "errors" "io/ioutil" "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/labstack/echo/v4" @@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) { assert.Equal(body, string(b)) } +func TestDecompressDefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() body := `{"name":"echo"}` @@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) { assert.Equal(t, body, string(reqBody)) } +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}`