diff --git a/middleware/proxy.go b/middleware/proxy.go index 495970aca..2744bc4a8 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -5,6 +5,7 @@ package middleware import ( "context" + "crypto/tls" "fmt" "io" "math/rand" @@ -130,7 +131,21 @@ var DefaultProxyConfig = ProxyConfig{ ContextKey: "target", } -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { +func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { + var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + if transport, ok := config.Transport.(*http.Transport); ok { + if transport.TLSClientConfig != nil { + d := tls.Dialer{ + Config: transport.TLSClientConfig, + } + dialFunc = d.DialContext + } + } + if dialFunc == nil { + var d net.Dialer + dialFunc = d.DialContext + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { @@ -138,13 +153,11 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return } defer in.Close() - - out, err := net.Dial("tcp", t.URL.Host) + out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } - defer out.Close() // Write header err = r.Write(out) @@ -365,7 +378,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) + proxyRaw(tgt, c, config).ServeHTTP(res, req) default: // even SSE requests proxyHTTP(tgt, c, config).ServeHTTP(res, req) } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index e87229ab5..dbf07648b 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -6,6 +6,7 @@ package middleware import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -20,6 +21,7 @@ import ( "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "golang.org/x/net/websocket" ) // Assert expected with url.EscapedPath method to obtain the path. @@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) { assert.Equal(t, "OK", rec.Body.String()) assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) } + +func createSimpleWebSocketServer(serveTLS bool) *httptest.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsHandler := func(conn *websocket.Conn) { + defer conn.Close() + for { + var msg string + err := websocket.Message.Receive(conn, &msg) + if err != nil { + return + } + // message back to the client + websocket.Message.Send(conn, msg) + } + } + websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) + }) + if serveTLS { + return httptest.NewTLSServer(handler) + } + return httptest.NewServer(handler) +} + +func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server { + e := echo.New() + + if toTLS { + // proxy to tls target + tgtURL, _ := url.Parse(srv.URL) + tgtURL.Scheme = "wss" + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Fatal("Default transport is not of type *http.Transport") + } + transport := defaultTransport.Clone() + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport})) + } else { + // proxy to non-TLS target + tgtURL, _ := url.Parse(srv.URL) + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer})) + } + + if serveTLS { + // serve proxy server with TLS + ts := httptest.NewTLSServer(e) + return ts + } + // serve proxy server without TLS + ts := httptest.NewServer(e) + return ts +} + +// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (non-TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, false, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (TLS to TLS) + ts := createSimpleProxyServer(t, srv, true, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (Non-TLS to TLS) + ts := createSimpleProxyServer(t, srv, false, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination) +func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, true, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to NoneTLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +}