diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go index dcd43e4431..5e856fb5b9 100644 --- a/fasthttpadaptor/adaptor.go +++ b/fasthttpadaptor/adaptor.go @@ -3,8 +3,11 @@ package fasthttpadaptor import ( + "bufio" "io" + "net" "net/http" + "sync" "github.com/valyala/fasthttp" ) @@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - - w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()} + w := netHTTPResponseWriter{ + w: ctx.Response.BodyWriter(), + ctx: ctx, + } h.ServeHTTP(&w, r.WithContext(ctx)) ctx.SetStatusCode(w.StatusCode()) @@ -86,6 +91,7 @@ type netHTTPResponseWriter struct { statusCode int h http.Header w io.Writer + ctx *fasthttp.RequestCtx } func (w *netHTTPResponseWriter) StatusCode() int { @@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { } func (w *netHTTPResponseWriter) Flush() {} + +type wrappedConn struct { + net.Conn + + wg sync.WaitGroup + once sync.Once +} + +func (c *wrappedConn) Close() (err error) { + c.once.Do(func() { + err = c.Conn.Close() + c.wg.Done() + }) + return +} + +func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or + // doing anything else with it. + w.ctx.HijackSetNoResponse(true) + + conn := &wrappedConn{Conn: w.ctx.Conn()} + conn.wg.Add(1) + w.ctx.Hijack(func(net.Conn) { + conn.wg.Wait() + }) + + bufW := bufio.NewWriter(conn) + + // Write any unflushed body to the hijacked connection buffer. + unflushedBody := w.ctx.Response.Body() + if len(unflushedBody) > 0 { + if _, err := bufW.Write(unflushedBody); err != nil { + conn.Close() + return nil, nil, err + } + } + + return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil +} diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go index a8a1ae830a..172ee54090 100644 --- a/fasthttpadaptor/adaptor_test.go +++ b/fasthttpadaptor/adaptor_test.go @@ -7,8 +7,10 @@ import ( "net/url" "reflect" "testing" + "time" "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" ) func TestNewFastHTTPHandler(t *testing.T) { @@ -143,3 +145,74 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value i next(ctx) } } + +func TestHijack(t *testing.T) { + t.Parallel() + + nethttpH := func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Hijacker); !ok { + t.Errorf("expected http.ResponseWriter to implement http.Hijacker") + } else { + if _, err := w.Write([]byte("foo")); err != nil { + t.Error(err) + } + + if c, rw, err := f.Hijack(); err != nil { + t.Error(err) + } else { + if _, err := rw.Write([]byte("bar")); err != nil { + t.Error(err) + } + + if err := rw.Flush(); err != nil { + t.Error(err) + } + + if err := c.Close(); err != nil { + t.Error(err) + } + } + } + } + + s := &fasthttp.Server{ + Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)), + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %v", err) + } + + buf, err := io.ReadAll(c) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if string(buf) != "foobar" { + t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar") + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +}