Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adaptor ResponseWriter - adding Hijack method and pass proper fields #1525

Merged
merged 6 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions fasthttpadaptor/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
package fasthttpadaptor

import (
"bufio"
"io"
"net"
"net/http"

"github.com/valyala/fasthttp"
Expand Down Expand Up @@ -45,16 +47,19 @@ func NewFastHTTPHandlerFunc(h http.HandlerFunc) fasthttp.RequestHandler {
// So it is advisable using this function only for quick net/http -> fasthttp
// switching. Then manually convert net/http handlers to fasthttp handlers
// according to https://github.com/valyala/fasthttp#switching-from-nethttp-to-fasthttp .
func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
//
// hijackHandler is used for registering handler for connection hijacking, this is usefull for cases
// where there is no access to change the server KeepHijackedConns field (which is default as false)
// it also can be used for additional custom hijacking logic
func NewFastHTTPHandler(h http.Handler, hijackHandler ...func(net.Conn)) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
var r http.Request
if err := ConvertRequest(ctx, &r, true); err != nil {
ctx.Logger().Printf("cannot parse requestURI %q: %v", r.RequestURI, err)
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}

w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
w := netHTTPResponseWriter{w: ctx.Response.BodyWriter(), r: ctx.RequestBodyStream(), conn: ctx.Conn(), ctx: ctx, hijackHandler: hijackHandler}
h.ServeHTTP(&w, r.WithContext(ctx))

ctx.SetStatusCode(w.StatusCode())
Expand Down Expand Up @@ -83,9 +88,13 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
}

type netHTTPResponseWriter struct {
statusCode int
h http.Header
w io.Writer
statusCode int
h http.Header
w io.Writer
r io.Reader
conn net.Conn
ctx *fasthttp.RequestCtx
hijackHandler []func(net.Conn)
}

func (w *netHTTPResponseWriter) StatusCode() int {
Expand All @@ -111,3 +120,10 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
}

func (w *netHTTPResponseWriter) Flush() {}

func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if len(w.hijackHandler) > 0 {
w.ctx.Hijack(w.hijackHandler[0])
}
return w.conn, &bufio.ReadWriter{Reader: bufio.NewReader(w.r), Writer: bufio.NewWriter(w.w)}, nil
}
76 changes: 76 additions & 0 deletions fasthttpadaptor/adaptor_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fasthttpadaptor

import (
"fmt"
"io"
"net"
"net/http"
Expand All @@ -9,6 +10,7 @@ import (
"testing"

"github.com/valyala/fasthttp"
"golang.org/x/sync/errgroup"
)

func TestNewFastHTTPHandler(t *testing.T) {
Expand Down Expand Up @@ -143,3 +145,77 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value i
next(ctx)
}
}

func TestHijackInterface1(t *testing.T) {
g := errgroup.Group{}

var (
reqCtx fasthttp.RequestCtx
req fasthttp.Request
)

client, server := net.Pipe()
testmsgc2s := "hello from a client"
testmsgs2c := "hello from a hijacked request"

reqCtx.Init2(server, nil, true)
req.CopyTo(&reqCtx.Request)

nethttpH := func(w http.ResponseWriter, r *http.Request) {
if h, ok := w.(http.Hijacker); !ok {
t.Fatalf("response writer do not support hijack interface")
} else if netConn, _, err := h.Hijack(); err != nil {
t.Fatalf("invoking Hijack failed: %s", err)
} else if netConn == nil {
t.Fatalf("invalid conn handler for hijack invokation")
} else {
readMsg := make([]byte, len(testmsgc2s))
n, err := io.ReadAtLeast(netConn, readMsg, len(testmsgc2s))
if err != nil {
t.Fatalf("server: error on read from conn: %s", err)
}
if n != len(testmsgc2s) || testmsgc2s != string(readMsg) {
t.Fatalf("server: mismatch on message recieved: expected: (%d)<%s>, actual: (%d)<%s>\n", len(testmsgc2s), testmsgc2s, n, string(readMsg))
}
n, err = io.WriteString(netConn, testmsgs2c)
if err != nil {
t.Fatalf("server: error on write to conn: %s", err)
}
if n != len(testmsgs2c) {
t.Fatalf("server: mismatch on message sent size: expected: (%d), actual: (%d)\n", len(testmsgc2s), n)
}
netConn.Close()
}
}

g.Go(func() error {
n, err := io.WriteString(client, testmsgc2s)
if err != nil {
return fmt.Errorf("client: error on write to conn: %s\n", err)
}
if n != len(testmsgc2s) {
return fmt.Errorf("client: mismatch on send all: expected: %d, actual: %d\n", len(testmsgc2s), n)
}
readMsg := make([]byte, len(testmsgs2c))
n, err = io.ReadAtLeast(client, readMsg, len(testmsgs2c))
if err != nil {
return fmt.Errorf("client: error on read from conn: %s", err)
}
if n != len(testmsgs2c) || testmsgs2c != string(readMsg) {
return fmt.Errorf("client: mismatch on message recieved: expected: (%d)<%s>, actual: (%d)<%s>\n", len(testmsgs2c), testmsgs2c, n, string(readMsg))
}
return nil
})

g.Go(func() error {
fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH), func(c net.Conn) {})
fasthttpH(&reqCtx)
if !reqCtx.Hijacked() {
t.Fatal("request was not hijacked")
}
return nil
})
if err := g.Wait(); err != nil {
t.Fatal(err)
}
}