Skip to content

Add support for TLS WebSocket proxy #2762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package middleware

import (
"context"
"crypto/tls"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -130,21 +131,33 @@ var DefaultProxyConfig = ProxyConfig{
ContextKey: "target",
}

func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
if transport, ok := config.Transport.(*http.Transport); ok {
if transport.TLSClientConfig != nil {
d := tls.Dialer{
Config: transport.TLSClientConfig,
}
dialFunc = d.DialContext
}
}
if dialFunc == nil {
var d net.Dialer
dialFunc = d.DialContext
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
return
}
defer in.Close()

out, err := net.Dial("tcp", t.URL.Host)
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()

// Write header
err = r.Write(out)
Expand Down Expand Up @@ -365,7 +378,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
proxyRaw(tgt, c, config).ServeHTTP(res, req)
default: // even SSE requests
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
Expand Down
230 changes: 230 additions & 0 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package middleware
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand All @@ -20,6 +21,7 @@ import (

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)

// Assert expected with url.EscapedPath method to obtain the path.
Expand Down Expand Up @@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}

func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
defer conn.Close()
for {
var msg string
err := websocket.Message.Receive(conn, &msg)
if err != nil {
return
}
// message back to the client
websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
})
if serveTLS {
return httptest.NewTLSServer(handler)
}
return httptest.NewServer(handler)
}

func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
e := echo.New()

if toTLS {
// proxy to tls target
tgtURL, _ := url.Parse(srv.URL)
tgtURL.Scheme = "wss"
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})

defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
t.Fatal("Default transport is not of type *http.Transport")
}
transport := defaultTransport.Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
} else {
// proxy to non-TLS target
tgtURL, _ := url.Parse(srv.URL)
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
}

if serveTLS {
// serve proxy server with TLS
ts := httptest.NewTLSServer(e)
return ts
}
// serve proxy server without TLS
ts := httptest.NewServer(e)
return ts
}

// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (non-TLS)
srv := createSimpleWebSocketServer(false)
defer srv.Close()

// create proxy server (non-TLS to non-TLS)
ts := createSimpleProxyServer(t, srv, false, false)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "ws"
tsURL.Path = "/"

/*
Act
*/

// Connect to the proxy WebSocket
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, Non TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

/*
Assert
*/
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
/*
Arrange
*/
// Create a WebSocket test server (TLS)
srv := createSimpleWebSocketServer(true)
defer srv.Close()

// create proxy server (TLS to TLS)
ts := createSimpleProxyServer(t, srv, true, true)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "wss"
tsURL.Path = "/"

/*
Act
*/
origin, err := url.Parse(ts.URL)
assert.NoError(t, err)
config := &websocket.Config{
Location: tsURL,
Origin: origin,
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
Version: websocket.ProtocolVersionHybi13,
}
wsConn, err := websocket.DialConfig(config)
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, TLS to TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
/*
Arrange
*/

// Create a WebSocket test server (TLS)
srv := createSimpleWebSocketServer(true)
defer srv.Close()

// create proxy server (Non-TLS to TLS)
ts := createSimpleProxyServer(t, srv, false, true)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "ws"
tsURL.Path = "/"

/*
Act
*/
// Connect to the proxy WebSocket
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, Non TLS to TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

/*
Assert
*/
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
/*
Arrange
*/

// Create a WebSocket test server (non-TLS)
srv := createSimpleWebSocketServer(false)
defer srv.Close()

// create proxy server (TLS to non-TLS)
ts := createSimpleProxyServer(t, srv, true, false)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "wss"
tsURL.Path = "/"

/*
Act
*/
origin, err := url.Parse(ts.URL)
assert.NoError(t, err)
config := &websocket.Config{
Location: tsURL,
Origin: origin,
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
Version: websocket.ProtocolVersionHybi13,
}
wsConn, err := websocket.DialConfig(config)
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, TLS to NoneTLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
Loading