Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package middleware

import (
"context"
"crypto/tls"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -130,7 +131,7 @@ var DefaultProxyConfig = ProxyConfig{
ContextKey: "target",
}

func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
Expand All @@ -139,12 +140,33 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
}
defer in.Close()

out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
var out net.Conn
if c.IsTLS() {
transport, ok := config.Transport.(*http.Transport)
if !ok {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, invalid transport type"))
return
}

if transport.TLSClientConfig == nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, TLSClientConfig is not set"))
return
}

out, err = tls.Dial("tcp", t.URL.Host, transport.TLSClientConfig)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
} else {
out, err = net.Dial("tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
}
defer out.Close()

// Write header
err = r.Write(out)
Expand Down Expand Up @@ -365,7 +387,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
proxyRaw(tgt, c, config).ServeHTTP(res, req)
default: // even SSE requests
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
Expand Down
137 changes: 137 additions & 0 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package middleware
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand All @@ -20,6 +21,7 @@ import (

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)

// Assert expected with url.EscapedPath method to obtain the path.
Expand Down Expand Up @@ -810,3 +812,138 @@ func TestModifyResponseUseContext(t *testing.T) {
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}

func TestProxyWithConfigWebSocketTCP(t *testing.T) {
/*
Arrange
*/
e := echo.New()

// Create a WebSocket test server
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
defer conn.Close()
for {
var msg string
err := websocket.Message.Receive(conn, &msg)
if err != nil {
return
}
// message back to the client
websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
}))
defer srv.Close()

tgtURL, _ := url.Parse(srv.URL)
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})

e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))

ts := httptest.NewServer(e)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "ws"
tsURL.Path = "/"

/*
Act
*/

// Connect to the proxy WebSocket
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

/*
Assert
*/
// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

func TestProxyWithConfigWebSocketTLS(t *testing.T) {
/*
Arrange
*/
e := echo.New()

// Create a WebSocket test server
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
defer conn.Close()
for {
var msg string
err := websocket.Message.Receive(conn, &msg)
if err != nil {
return
}
// message back to the client
websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
}))
defer srv.Close()

// create proxy server
tgtURL, _ := url.Parse(srv.URL)
tgtURL.Scheme = "wss"

balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})

defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
t.Fatal("Default transport is not of type *http.Transport")
}
transport := defaultTransport.Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))

// Start test server
ts := httptest.NewTLSServer(e)
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
tsURL.Scheme = "wss"
tsURL.Path = "/"

/*
Act
*/
origin, err := url.Parse(ts.URL)
assert.NoError(t, err)
config := &websocket.Config{
Location: tsURL,
Origin: origin,
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
Version: websocket.ProtocolVersionHybi13,
}
wsConn, err := websocket.DialConfig(config)
assert.NoError(t, err)
defer wsConn.Close()

// Send message
sendMsg := "Hello, TLS WebSocket!"
err = websocket.Message.Send(wsConn, sendMsg)
assert.NoError(t, err)

// Read response
var recvMsg string
err = websocket.Message.Receive(wsConn, &recvMsg)
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
Loading