Skip to content

Commit 9c0c71e

Browse files
authored
reverseproxy: Rewrite requests and responses for websocket over http2 (#6567)
* reverse proxy: rewrite requests and responses for websocket over http2 * delete protocol pseudo-header * modify cloned requests * set request variable to track if it's a h2 websocket * use request bodu * rewrite request body * use WebSocket instead of Websocket in the headers * use logger check for zap loggers * fix lint
1 parent a1751ad commit 9c0c71e

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

modules/caddyhttp/reverseproxy/reverseproxy.go

+19
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package reverseproxy
1717
import (
1818
"bytes"
1919
"context"
20+
"crypto/rand"
21+
"encoding/base64"
2022
"encoding/json"
2123
"errors"
2224
"fmt"
@@ -394,6 +396,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
394396
return caddyhttp.Error(http.StatusInternalServerError,
395397
fmt.Errorf("preparing request for upstream round-trip: %v", err))
396398
}
399+
// websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
400+
// TODO: once we can reliably detect backend support this, it can be removed for those backends
401+
if r.ProtoMajor == 2 && r.Method == http.MethodConnect && r.Header.Get(":protocol") != "" {
402+
clonedReq.Header.Del(":protocol")
403+
// keep the body for later use. http1.1 upgrade uses http.NoBody
404+
caddyhttp.SetVar(clonedReq.Context(), "h2_websocket_body", clonedReq.Body)
405+
clonedReq.Body = http.NoBody
406+
clonedReq.Method = http.MethodGet
407+
clonedReq.Header.Set("Upgrade", r.Header.Get(":protocol"))
408+
clonedReq.Header.Set("Connection", "Upgrade")
409+
key := make([]byte, 16)
410+
_, randErr := rand.Read(key)
411+
if randErr != nil {
412+
return randErr
413+
}
414+
clonedReq.Header["Sec-WebSocket-Key"] = []string{base64.StdEncoding.EncodeToString(key)}
415+
}
397416

398417
// we will need the original headers and Host value if
399418
// header operations are configured; this is so that each

modules/caddyhttp/reverseproxy/streaming.go

+69-13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package reverseproxy
2020

2121
import (
22+
"bufio"
2223
"context"
2324
"errors"
2425
"fmt"
@@ -33,8 +34,29 @@ import (
3334
"go.uber.org/zap"
3435
"go.uber.org/zap/zapcore"
3536
"golang.org/x/net/http/httpguts"
37+
38+
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
3639
)
3740

41+
type h2ReadWriteCloser struct {
42+
io.ReadCloser
43+
http.ResponseWriter
44+
}
45+
46+
func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
47+
n, err = rwc.ResponseWriter.Write(p)
48+
if err != nil {
49+
return 0, err
50+
}
51+
52+
//nolint:bodyclose
53+
err = http.NewResponseController(rwc.ResponseWriter).Flush()
54+
if err != nil {
55+
return 0, err
56+
}
57+
return n, nil
58+
}
59+
3860
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
3961
reqUpType := upgradeType(req.Header)
4062
resUpType := upgradeType(res.Header)
@@ -67,24 +89,58 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
6789
// like the rest of handler chain.
6890
copyHeader(rw.Header(), res.Header)
6991
normalizeWebsocketHeaders(rw.Header())
70-
rw.WriteHeader(res.StatusCode)
7192

72-
logger.Debug("upgrading connection")
93+
var (
94+
conn io.ReadWriteCloser
95+
brw *bufio.ReadWriter
96+
)
97+
// websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
98+
// TODO: once we can reliably detect backend support this, it can be removed for those backends
99+
if body, ok := caddyhttp.GetVar(req.Context(), "h2_websocket_body").(io.ReadCloser); ok {
100+
req.Body = body
101+
rw.Header().Del("Upgrade")
102+
rw.Header().Del("Connection")
103+
delete(rw.Header(), "Sec-WebSocket-Accept")
104+
rw.WriteHeader(http.StatusOK)
105+
106+
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
107+
c.Write(zap.Int("http_version", 2))
108+
}
73109

74-
//nolint:bodyclose
75-
conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
76-
if errors.Is(hijackErr, http.ErrNotSupported) {
77-
if c := logger.Check(zapcore.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
78-
c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
110+
//nolint:bodyclose
111+
flushErr := http.NewResponseController(rw).Flush()
112+
if flushErr != nil {
113+
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
114+
c.Write(zap.Error(flushErr))
115+
}
116+
return
79117
}
80-
return
81-
}
118+
conn = h2ReadWriteCloser{req.Body, rw}
119+
// bufio is not needed, use minimal buffer
120+
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
121+
} else {
122+
rw.WriteHeader(res.StatusCode)
82123

83-
if hijackErr != nil {
84-
if c := logger.Check(zapcore.ErrorLevel, "hijack failed on protocol switch"); c != nil {
85-
c.Write(zap.Error(hijackErr))
124+
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
125+
c.Write(zap.Int("http_version", req.ProtoMajor))
126+
}
127+
128+
var hijackErr error
129+
//nolint:bodyclose
130+
conn, brw, hijackErr = http.NewResponseController(rw).Hijack()
131+
if errors.Is(hijackErr, http.ErrNotSupported) {
132+
if c := h.logger.Check(zap.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
133+
c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
134+
}
135+
return
136+
}
137+
138+
if hijackErr != nil {
139+
if c := h.logger.Check(zap.ErrorLevel, "hijack failed on protocol switch"); c != nil {
140+
c.Write(zap.Error(hijackErr))
141+
}
142+
return
86143
}
87-
return
88144
}
89145

90146
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5

0 commit comments

Comments
 (0)