Skip to content

Commit

Permalink
Allow no response to be send when a connection is hijacked (#712)
Browse files Browse the repository at this point in the history
* Allow no response to be send when a connection is hijacked

At the moment there is always a HTTP response before the connection gets
hijacked. This second option to Hijack() prevents this response from
being send.

Fixes: #698

* Add HijackSetNoResponse method instead
  • Loading branch information
erikdubbelboer authored Dec 29, 2019
1 parent 0724b3e commit 958ed36
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
66 changes: 41 additions & 25 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ type RequestCtx struct {
timeoutCh chan struct{}
timeoutTimer *time.Timer

hijackHandler HijackHandler
hijackHandler HijackHandler
hijackNoResponse bool
}

// HijackHandler must process the hijacked connection c.
Expand All @@ -535,6 +536,7 @@ type HijackHandler func(c net.Conn)
// * Unexpected error during response writing to the connection.
//
// The server stops processing requests from hijacked connections.
//
// Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc.
// aren't applied to hijacked connections.
//
Expand All @@ -550,6 +552,15 @@ func (ctx *RequestCtx) Hijack(handler HijackHandler) {
ctx.hijackHandler = handler
}

// HijackSetNoResponse changes the behavior of hijacking a request.
// If HijackSetNoResponse is called with false fasthttp will send a response
// to the client before calling the HijackHandler (default). If HijackSetNoResponse
// is called with true no response is send back before calling the
// HijackHandler supplied in the Hijack function.
func (ctx *RequestCtx) HijackSetNoResponse(noResponse bool) {
ctx.hijackNoResponse = noResponse
}

// Hijacked returns true after Hijack is called.
func (ctx *RequestCtx) Hijacked() bool {
return ctx.hijackHandler != nil
Expand Down Expand Up @@ -1869,9 +1880,10 @@ func (s *Server) serveConn(c net.Conn) error {
br *bufio.Reader
bw *bufio.Writer

err error
timeoutResponse *Response
hijackHandler HijackHandler
err error
timeoutResponse *Response
hijackHandler HijackHandler
hijackNoResponse bool

connectionClose bool
isHTTP11 bool
Expand Down Expand Up @@ -2044,6 +2056,8 @@ func (s *Server) serveConn(c net.Conn) error {

hijackHandler = ctx.hijackHandler
ctx.hijackHandler = nil
hijackNoResponse = ctx.hijackNoResponse
ctx.hijackNoResponse = false

ctx.userValues.Reset()

Expand Down Expand Up @@ -2071,30 +2085,32 @@ func (s *Server) serveConn(c net.Conn) error {
ctx.Response.Header.SetServerBytes(serverName)
}

if bw == nil {
bw = acquireWriter(ctx)
}
if err = writeResponse(ctx, bw); err != nil {
break
}
if !hijackNoResponse {
if bw == nil {
bw = acquireWriter(ctx)
}
if err = writeResponse(ctx, bw); err != nil {
break
}

// Only flush the writer if we don't have another request in the pipeline.
// This is a big of an ugly optimization for https://www.techempower.com/benchmarks/
// This benchmark will send 16 pipelined requests. It is faster to pack as many responses
// in a TCP packet and send it back at once than waiting for a flush every request.
// In real world circumstances this behaviour could be argued as being wrong.
if br == nil || br.Buffered() == 0 || connectionClose {
err = bw.Flush()
if err != nil {
// Only flush the writer if we don't have another request in the pipeline.
// This is a big of an ugly optimization for https://www.techempower.com/benchmarks/
// This benchmark will send 16 pipelined requests. It is faster to pack as many responses
// in a TCP packet and send it back at once than waiting for a flush every request.
// In real world circumstances this behaviour could be argued as being wrong.
if br == nil || br.Buffered() == 0 || connectionClose {
err = bw.Flush()
if err != nil {
break
}
}
if connectionClose {
break
}
}
if connectionClose {
break
}
if s.ReduceMemoryUsage {
releaseWriter(s, bw)
bw = nil
if s.ReduceMemoryUsage && hijackHandler == nil {
releaseWriter(s, bw)
bw = nil
}
}

if hijackHandler != nil {
Expand Down
45 changes: 45 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2098,6 +2098,51 @@ func TestRequestCtxHijack(t *testing.T) {
}
}

func TestRequestCtxHijackNoResponse(t *testing.T) {
t.Parallel()

hijackDone := make(chan error)
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Hijack(func(c net.Conn) {
_, err := c.Write([]byte("test"))
hijackDone <- err
})
ctx.HijackSetNoResponse(true)
},
}

rw := &readWriter{}
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")

ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()

select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}

select {
case err := <-hijackDone:
if err != nil {
t.Fatalf("Unexpected error from hijack: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}

if got := rw.w.String(); got != "test" {
t.Errorf(`expected "test", got %q`, got)
}
}

func TestRequestCtxInit(t *testing.T) {
var ctx RequestCtx
var logger testLogger
Expand Down

0 comments on commit 958ed36

Please sign in to comment.