From ac3d87d0dc96ba404c81dc7ba580561bb656aff3 Mon Sep 17 00:00:00 2001 From: Daniel Morsing Date: Sat, 6 Dec 2014 14:12:04 +0000 Subject: [PATCH 1/2] unify Headers and Push Promise writes. As predicted, I ended up unifying the Headers and Push Promise packing. This is in preparation for reusing the writeResHeaders functionality of sending continuation frames if the headers are too big. --- frame.go | 84 +++++++++++++++----------------------------------- frame_test.go | 5 +-- server_test.go | 9 +++--- 3 files changed, 32 insertions(+), 66 deletions(-) diff --git a/frame.go b/frame.go index 37a5f10d..4339ece1 100644 --- a/frame.go +++ b/frame.go @@ -771,6 +771,9 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) { // HeadersFrameParam are the parameters for writing a HEADERS frame. type HeadersFrameParam struct { + // Write this header as a push promise. + PushPromise bool + // StreamID is the required Stream ID to initiate. StreamID uint32 // BlockFragment is part (or all) of a Header Block. @@ -779,7 +782,7 @@ type HeadersFrameParam struct { // EndStream indicates that the header block is the last that // the endpoint will send for the identified stream. Setting // this flag causes the stream to enter one of "half closed" - // states. + // states. It is not sent if this write is a push promise. EndStream bool // EndHeaders indicates that this frame contains an entire @@ -792,11 +795,14 @@ type HeadersFrameParam struct { PadLength uint8 // Priority, if non-zero, includes stream priority information - // in the HEADER frame. + // in the HEADER frame. It is not sent if this write is a PushPromise. Priority PriorityParam + + // PromiseID is the promised stream if this write is a PushPromise. + PromiseID uint32 } -// WriteHeaders writes a single HEADERS frame. +// WriteHeaders writes a single HEADERS or PUSH_PROMISE frame. // // This is a low-level header writing method. Encoding headers and // splitting them into any necessary CONTINUATION frames is handled @@ -812,20 +818,24 @@ func (f *Framer) WriteHeaders(p HeadersFrameParam) error { if p.PadLength != 0 { flags |= FlagHeadersPadded } - if p.EndStream { + if p.EndStream && !p.PushPromise { flags |= FlagHeadersEndStream } if p.EndHeaders { flags |= FlagHeadersEndHeaders } - if !p.Priority.IsZero() { + if !p.Priority.IsZero() && !p.PushPromise { flags |= FlagHeadersPriority } - f.startWrite(FrameHeaders, flags, p.StreamID) + fh := FrameHeaders + if p.PushPromise { + fh = FramePushPromise + } + f.startWrite(fh, flags, p.StreamID) if p.PadLength != 0 { f.writeByte(p.PadLength) } - if !p.Priority.IsZero() { + if !p.Priority.IsZero() && !p.PushPromise { v := p.Priority.StreamDep if !validStreamID(v) && !f.AllowIllegalWrites { return errors.New("invalid dependent stream id") @@ -836,6 +846,13 @@ func (f *Framer) WriteHeaders(p HeadersFrameParam) error { f.writeUint32(v) f.writeByte(p.Priority.Weight) } + if p.PushPromise { + if !validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + return errStreamID + } + f.writeUint32(p.PromiseID) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) f.wbuf = append(f.wbuf, padZeros[:p.PadLength]...) return f.endWrite() @@ -1026,59 +1043,6 @@ func parsePushPromise(fh FrameHeader, p []byte) (_ Frame, err error) { return pp, nil } -// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. -type PushPromiseParam struct { - // StreamID is the required Stream ID to initiate. - StreamID uint32 - - // PromiseID is the required Stream ID which this - // Push Promises - PromiseID uint32 - - // BlockFragment is part (or all) of a Header Block. - BlockFragment []byte - - // EndHeaders indicates that this frame contains an entire - // header block and is not followed by any - // CONTINUATION frames. - EndHeaders bool - - // PadLength is the optional number of bytes of zeros to add - // to this frame. - PadLength uint8 -} - -// WritePushPromise writes a single PushPromise Frame. -// -// As with Header Frames, This is the low level call for writing -// individual frames. Continuation frames are handled elsewhere. -// -// It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. -func (f *Framer) WritePushPromise(p PushPromiseParam) error { - if !validStreamID(p.StreamID) && !f.AllowIllegalWrites { - return errStreamID - } - var flags Flags - if p.PadLength != 0 { - flags |= FlagPushPromisePadded - } - if p.EndHeaders { - flags |= FlagPushPromiseEndHeaders - } - f.startWrite(FramePushPromise, flags, p.StreamID) - if p.PadLength != 0 { - f.writeByte(p.PadLength) - } - if !validStreamID(p.PromiseID) && !f.AllowIllegalWrites { - return errStreamID - } - f.writeUint32(p.PromiseID) - f.wbuf = append(f.wbuf, p.BlockFragment...) - f.wbuf = append(f.wbuf, padZeros[:p.PadLength]...) - return f.endWrite() -} - // WriteRawFrame writes a raw frame. This can be used to write // extension frames unknown to this package. func (f *Framer) WriteRawFrame(t FrameType, flags Flags, streamID uint32, payload []byte) error { diff --git a/frame_test.go b/frame_test.go index d09db62f..ec70c8c7 100644 --- a/frame_test.go +++ b/frame_test.go @@ -511,13 +511,14 @@ func TestWriteGoAway(t *testing.T) { } func TestWritePushPromise(t *testing.T) { - pp := PushPromiseParam{ + pp := HeadersFrameParam{ + PushPromise: true, StreamID: 42, PromiseID: 42, BlockFragment: []byte("abc"), } fr, buf := testFramer() - if err := fr.WritePushPromise(pp); err != nil { + if err := fr.WriteHeaders(pp); err != nil { t.Fatal(err) } const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc" diff --git a/server_test.go b/server_test.go index ef8b6e6a..6a3d63df 100644 --- a/server_test.go +++ b/server_test.go @@ -1191,11 +1191,12 @@ func TestServer_Rejects_Continuation0(t *testing.T) { func TestServer_Rejects_PushPromise(t *testing.T) { testServerRejects(t, func(st *serverTester) { - pp := PushPromiseParam{ - StreamID: 1, - PromiseID: 3, + pp := HeadersFrameParam{ + PushPromise: true, + StreamID: 1, + PromiseID: 3, } - if err := st.fr.WritePushPromise(pp); err != nil { + if err := st.fr.WriteHeaders(pp); err != nil { t.Fatal(err) } }) From c2f48436be1e99cabe0e4d1ed29300736d66d817 Mon Sep 17 00:00:00 2001 From: Daniel Morsing Date: Fri, 5 Dec 2014 22:30:25 +0000 Subject: [PATCH 2/2] initial implementation of server push The API uses loopback to the top handler in order to make sure that the resources being pushed are related to the resource being fetched. If the header parameter is nil, we copy the headers from the initiating request. This is mostly a shortcut for the common case where we don't want to specify any new request headers. --- h2demo/h2demo.go | 12 +++- server.go | 146 +++++++++++++++++++++++++++++++++++++++++++++-- server_test.go | 37 ++++++++++++ write.go | 45 +++++++++++++-- 4 files changed, 228 insertions(+), 12 deletions(-) diff --git a/h2demo/h2demo.go b/h2demo/h2demo.go index 4f92c9c2..6999a74b 100644 --- a/h2demo/h2demo.go +++ b/h2demo/h2demo.go @@ -285,6 +285,17 @@ func newGopherTilesHandler() http.Handler { return } } + cacheBust := time.Now().UnixNano() + + if p, ok := w.(http2.Pusher); ok { + for y := 0; y < yt; y++ { + for x := 0; x < xt; x++ { + path := fmt.Sprintf("/gophertiles?x=%d&y=%d&cachebust=%d&latency=%d", x, y, cacheBust, ms) + p.Push("GET", path, nil) + } + } + } + io.WriteString(w, "") fmt.Fprintf(w, "A grid of %d tiled images is below. Compare:

", xt*yt) for _, ms := range []int{0, 30, 200, 1000} { @@ -295,7 +306,6 @@ func newGopherTilesHandler() http.Handler { ) } io.WriteString(w, "

\n") - cacheBust := time.Now().UnixNano() for y := 0; y < yt; y++ { for x := 0; x < xt; x++ { fmt.Fprintf(w, "", diff --git a/server.go b/server.go index aba029aa..53ac9d6e 100644 --- a/server.go +++ b/server.go @@ -192,9 +192,11 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { bw: newBufferedWriter(c), handler: h, streams: make(map[uint32]*stream), + maxPushID: 2, readFrameCh: make(chan frameAndGate), readFrameErrCh: make(chan error, 1), // must be buffered for 1 wantWriteFrameCh: make(chan frameWriteMsg, 8), + pushPromiseCh: make(chan *pushPromise, 1), wroteFrameCh: make(chan struct{}, 1), // buffered; one send in reading goroutine bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way doneServing: make(chan struct{}), @@ -316,6 +318,7 @@ type serverConn struct { readFrameCh chan frameAndGate // written by serverConn.readFrames readFrameErrCh chan error wantWriteFrameCh chan frameWriteMsg // from handlers -> serve + pushPromiseCh chan *pushPromise // server loop receives push promises on this chan wroteFrameCh chan struct{} // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan bodyReadMsg // from handlers -> serve testHookCh chan func() // code to run on the serve loop @@ -334,6 +337,7 @@ type serverConn struct { advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client curOpenStreams uint32 // client's number of open streams maxStreamID uint32 // max ever seen + maxPushID uint32 // maxID used for push streams map[uint32]*stream initialWindowSize int32 headerTableSize uint32 @@ -354,9 +358,10 @@ type serverConn struct { hpackEncoder *hpack.Encoder } -// requestParam is the state of the next request, initialized over -// potentially several frames HEADERS + zero or more CONTINUATION -// frames. +// requestParam is the state of a request. +// It is used in 2 places. One is as a accumulator of the state in +// HEADERS + zero or more CONTINUATION frames in the read loop. +// The other place is for constructing a push promise frame. type requestParam struct { // stream is non-nil if we're reading (HEADER or CONTINUATION) // frames for a request (but not DATA). @@ -613,6 +618,8 @@ func (sc *serverConn) serve() { select { case wm := <-sc.wantWriteFrameCh: sc.writeFrame(wm) + case pp := <-sc.pushPromiseCh: + sc.startPushPromise(pp) case <-sc.wroteFrameCh: sc.writingFrame = false sc.scheduleFrameWrite() @@ -822,6 +829,28 @@ func (sc *serverConn) scheduleFrameWrite() { } } +func (sc *serverConn) pushPromise(pp *pushPromise) error { + sc.serveG.checkNotOn() // NOT + st := pp.reqpm.stream + select { + case sc.pushPromiseCh <- pp: + case <-sc.doneServing: + return errClientDisconnected + case <-st.cw: + return errStreamBroken + } + + select { + case err := <-pp.done: + return err + case <-sc.doneServing: + return errClientDisconnected + case <-st.cw: + return errStreamBroken + } + +} + func (sc *serverConn) goAway(code ErrCode) { sc.serveG.check() if sc.inGoAway { @@ -1223,6 +1252,71 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) } +var ErrPushDisabled = errors.New("http2: push attempted on connection where it is disabled") + +func (sc *serverConn) startPushPromise(pp *pushPromise) { + sc.serveG.check() + if !sc.pushEnabled { + pp.done <- ErrPushDisabled + return + } + assocStream := pp.reqpm.stream + promiseid := sc.maxPushID + sc.maxPushID += 2 + + // create the stream that is going to be pushed + st := &stream{ + id: promiseid, + state: stateResvLocal, + } + st.flow.conn = &sc.flow + st.flow.add(sc.initialWindowSize) + st.cw.Init() + sc.streams[promiseid] = st + + // TODO(dmorsing): figure out if priority is + // a factor between the initiating stream and + // the pushed one + + // A bit ugly: we use the stream field in + // reqpm to tell us which stream initiated this + // push, then overwrite it here with the created + // stream + pp.reqpm.stream = st + + // We need to make sure that the push + // promise frame was sent to the client + // before we start handler. + // Otherwise, we might send the headers + // for the response before the push promise. + starthandlerCh := make(chan error, 1) + sc.writeFrame(frameWriteMsg{ + &writePPHeaders{ + streamID: assocStream.id, + reqpm: &pp.reqpm, + }, + assocStream, + starthandlerCh}) + rw, req, err := sc.newWriterAndRequest(&pp.reqpm) + if err != nil { + panic("Created bad request for push") + } + go func() { + select { + case <-sc.doneServing: + return + case <-st.cw: + return + case err := <-starthandlerCh: + pp.done <- err + if err != nil { + return + } + } + sc.runHandler(rw, req) + }() +} + func (sc *serverConn) processContinuation(f *ContinuationFrame) error { sc.serveG.check() st := sc.streams[f.Header().StreamID] @@ -1265,7 +1359,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo return StreamError{st.id, ErrCodeRefusedStream} } - rw, req, err := sc.newWriterAndRequest() + rw, req, err := sc.newWriterAndRequest(&sc.req) if err != nil { return err } @@ -1308,9 +1402,8 @@ func (sc *serverConn) resetPendingRequest() { sc.req = requestParam{} } -func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequest(rp *requestParam) (*responseWriter, *http.Request, error) { sc.serveG.check() - rp := &sc.req if rp.invalidHeader || rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: @@ -1548,6 +1641,7 @@ var ( _ http.CloseNotifier = (*responseWriter)(nil) _ http.Flusher = (*responseWriter)(nil) _ stringWriter = (*responseWriter)(nil) + _ Pusher = (*responseWriter)(nil) ) type responseWriterState struct { @@ -1642,6 +1736,46 @@ func (w *responseWriter) Flush() { } } +type pushPromise struct { + reqpm requestParam + done chan error +} + +// Pusher is an interface that http.ResponseWriters implement if they support server push +// When Push is called, it will send a push promise to the client, using the method, path and +// headers. It will then initiate the push by calling the server-level http.Handler +// for the path. +type Pusher interface { + Push(method string, path string, h http.Header) error +} + +func (w *responseWriter) Push(method string, path string, h http.Header) error { + rws := w.rws + if rws == nil { + panic("Push called after Handler finished") + } + if h == nil { + h = rws.req.Header + } + switch method { + case "GET", "HEAD": + default: + return errors.New("http2: invalid method for push promise") + } + pp := pushPromise{ + reqpm: requestParam{ + stream: rws.stream, + header: cloneHeader(h), + method: method, + path: path, + scheme: "https", + authority: rws.req.Host, + }, + done: rws.frameWriteCh, + } + return rws.conn.pushPromise(&pp) +} + func (w *responseWriter) CloseNotify() <-chan bool { rws := w.rws if rws == nil { diff --git a/server_test.go b/server_test.go index 6a3d63df..dcccfc96 100644 --- a/server_test.go +++ b/server_test.go @@ -422,6 +422,18 @@ func (st *serverTester) wantSettingsAck() { } +func (st *serverTester) wantPushPromise() *PushPromiseFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + ppf, ok := f.(*PushPromiseFrame) + if !ok { + st.t.Fatalf("Wanted PushPromise, received %T", ppf) + } + return ppf +} + func TestServer(t *testing.T) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -1767,6 +1779,31 @@ func TestServer_Response_Automatic100Continue(t *testing.T) { }) } +func TestServer_Response_PushPromise(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + if r.URL.Path == "/" { + p := w.(Pusher) + p.Push("GET", "/push", nil) + } + }) + st.greet() + getSlash(st) + ppf := st.wantPushPromise() + if !ppf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + for i := 0; i < 2; i++ { + hf := st.wantHeaders() + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + if !hf.StreamEnded() { + t.Fatal("want END_STREAM flag") + } + } +} + func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { errc := make(chan error, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { diff --git a/write.go b/write.go index 7b9bdd3a..c9f7a1cd 100644 --- a/write.go +++ b/write.go @@ -142,7 +142,37 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error { if len(headerBlock) == 0 { panic("unexpected empty hpack") } + return writeHeaders(ctx, w.streamID, 0, w.endStream, headerBlock) +} +// writePPHeaders is a request to write a Push Promise and 0+ CONTINUATION frames +// TODO(dmorsing): when we have a client, reuse this to code to write requests +type writePPHeaders struct { + streamID uint32 + reqpm *requestParam +} + +func (w *writePPHeaders) writeFrame(ctx writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":method", Value: w.reqpm.method}) + enc.WriteField(hpack.HeaderField{Name: ":path", Value: w.reqpm.path}) + enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: w.reqpm.scheme}) + enc.WriteField(hpack.HeaderField{Name: ":authority", Value: w.reqpm.authority}) + for k, vv := range w.reqpm.header { + k = lowerHeader(k) + for _, v := range vv { + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) + } + } + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + return writeHeaders(ctx, w.streamID, w.reqpm.stream.id, true, headerBlock) +} + +func writeHeaders(ctx writeContext, streamid uint32, promiseid uint32, endStream bool, headerBlock []byte) error { // For now we're lazy and just pick the minimum MAX_FRAME_SIZE // that all peers must support (16KB). Later we could care // more and send larger frames if the peer advertised it, but @@ -162,14 +192,19 @@ func (w *writeResHeaders) writeFrame(ctx writeContext) error { var err error if first { first = false - err = ctx.Framer().WriteHeaders(HeadersFrameParam{ - StreamID: w.streamID, + hfp := HeadersFrameParam{ + StreamID: streamid, BlockFragment: frag, - EndStream: w.endStream, + EndStream: endStream, EndHeaders: endHeaders, - }) + } + if promiseid != 0 { + hfp.PushPromise = true + hfp.PromiseID = promiseid + } + err = ctx.Framer().WriteHeaders(hfp) } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) + err = ctx.Framer().WriteContinuation(streamid, endHeaders, frag) } if err != nil { return err