diff --git a/http2.go b/http2.go index 3ed51a44e8..4c5e11ac2c 100644 --- a/http2.go +++ b/http2.go @@ -17,6 +17,7 @@ package http2 import ( "bufio" + "crypto/tls" "errors" "fmt" "io" @@ -422,3 +423,7 @@ var isTokenTable = [127]bool{ '|': true, '~': true, } + +type connectionStater interface { + ConnectionState() tls.ConnectionState +} diff --git a/server.go b/server.go index f24a8e828c..6f4c2bb7cd 100644 --- a/server.go +++ b/server.go @@ -195,28 +195,76 @@ func ConfigureServer(s *http.Server, conf *Server) error { if testHookOnConn != nil { testHookOnConn() } - conf.handleConn(hs, c, h) + conf.ServeConn(c, &ServeConnOpts{ + Handler: h, + BaseConfig: hs, + }) } s.TLSNextProto[NextProtoTLS] = protoHandler s.TLSNextProto["h2-14"] = protoHandler // temporary; see above. return nil } -func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { +// ServeConnOpts are options for the Server.ServeConn method. +type ServeConnOpts struct { + // BaseConfig optionally sets the base configuration + // for values. If nil, defaults are used. + BaseConfig *http.Server + + // Handler specifies which handler to use for processing + // requests. If nil, BaseConfig.Handler is used. If BaseConfig + // or BaseConfig.Handler is nil, http.DefaultServeMux is used. + Handler http.Handler +} + +func (o *ServeConnOpts) baseConfig() *http.Server { + if o != nil && o.BaseConfig != nil { + return o.BaseConfig + } + return new(http.Server) +} + +func (o *ServeConnOpts) handler() http.Handler { + if o != nil { + if o.Handler != nil { + return o.Handler + } + if o.BaseConfig != nil && o.BaseConfig.Handler != nil { + return o.BaseConfig.Handler + } + } + return http.DefaultServeMux +} + +// ServeConn serves HTTP/2 requests on the provided connection and +// blocks until the connection is no longer readable. +// +// ServeConn starts speaking HTTP/2 assuming that c has not had any +// reads or writes. It writes its initial settings frame and expects +// to be able to read the preface and settings frame from the +// client. If c has a ConnectionState method like a *tls.Conn, the +// ConnectionState is used to verify the TLS ciphersuite and to set +// the Request.TLS field in Handlers. +// +// ServeConn does not support h2c by itself. Any h2c support must be +// implemented in terms of providing a suitably-behaving net.Conn. +// +// The opts parameter is optional. If nil, default values are used. +func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { sc := &serverConn{ - srv: srv, - hs: hs, + srv: s, + hs: opts.baseConfig(), conn: c, remoteAddrStr: c.RemoteAddr().String(), bw: newBufferedWriter(c), - handler: h, + handler: opts.handler(), streams: make(map[uint32]*stream), readFrameCh: make(chan readFrameResult), wantWriteFrameCh: make(chan frameWriteMsg, 8), wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way doneServing: make(chan struct{}), - advMaxStreams: srv.maxConcurrentStreams(), + advMaxStreams: s.maxConcurrentStreams(), writeSched: writeScheduler{ maxFrameSize: initialMaxFrameSize, }, @@ -232,10 +280,10 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) fr := NewFramer(sc.bw, c) - fr.SetMaxReadFrameSize(srv.maxReadFrameSize()) + fr.SetMaxReadFrameSize(s.maxReadFrameSize()) sc.framer = fr - if tc, ok := c.(*tls.Conn); ok { + if tc, ok := c.(connectionStater); ok { sc.tlsState = new(tls.ConnectionState) *sc.tlsState = tc.ConnectionState() // 9.2 Use of TLS Features @@ -265,7 +313,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { // So for now, do nothing here again. } - if !srv.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { + if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { // "Endpoints MAY choose to generate a connection error // (Section 5.4.1) of type INADEQUATE_SECURITY if one of // the prohibited cipher suites are negotiated." diff --git a/server_test.go b/server_test.go index 0bde49698d..2d1578e7c0 100644 --- a/server_test.go +++ b/server_test.go @@ -2808,12 +2808,16 @@ func TestIssue53(t *testing.T) { "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad" s := &http.Server{ ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags), + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("hello")) + }), + } + s2 := &Server{ + MaxReadFrameSize: 1 << 16, + PermitProhibitedCipherSuites: true, } - s2 := &Server{MaxReadFrameSize: 1 << 16, PermitProhibitedCipherSuites: true} c := &issue53Conn{[]byte(data), false, false} - s2.handleConn(s, c, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("hello")) - })) + s2.ServeConn(c, &ServeConnOpts{BaseConfig: s}) if !c.closed { t.Fatal("connection is not closed") } @@ -2977,3 +2981,102 @@ func TestServerNoDuplicateContentType(t *testing.T) { t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) } } + +type connStateConn struct { + net.Conn + cs tls.ConnectionState +} + +func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs } + +// golang.org/issue/12737 -- handle any net.Conn, not just +// *tls.Conn. +func TestServerHandleCustomConn(t *testing.T) { + var s Server + c1, c2 := net.Pipe() + clientDone := make(chan struct{}) + handlerDone := make(chan struct{}) + var req *http.Request + go func() { + defer close(clientDone) + defer c2.Close() + fr := NewFramer(c2, c2) + io.WriteString(c2, ClientPreface) + fr.WriteSettings() + fr.WriteSettingsAck() + f, err := fr.ReadFrame() + if err != nil { + t.Error(err) + return + } + if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() { + t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f)) + return + } + f, err = fr.ReadFrame() + if err != nil { + t.Error(err) + return + } + if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() { + t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f)) + return + } + var henc hpackEncoder + fr.WriteHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"), + EndStream: true, + EndHeaders: true, + }) + go io.Copy(ioutil.Discard, c2) + <-handlerDone + }() + const testString = "my custom ConnectionState" + fakeConnState := tls.ConnectionState{ + ServerName: testString, + Version: tls.VersionTLS12, + } + go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{ + BaseConfig: &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + req = r + }), + }}) + select { + case <-clientDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for handler") + } + if req.TLS == nil { + t.Fatalf("Request.TLS is nil. Got: %#v", req) + } + if req.TLS.ServerName != testString { + t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString) + } +} + +type hpackEncoder struct { + enc *hpack.Encoder + buf bytes.Buffer +} + +func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + he.buf.Reset() + if he.enc == nil { + he.enc = hpack.NewEncoder(&he.buf) + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } + headers = headers[2:] + } + return he.buf.Bytes() +} diff --git a/transport.go b/transport.go index 6c4377de8b..8e2e5f2cf3 100644 --- a/transport.go +++ b/transport.go @@ -406,9 +406,6 @@ func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { // henc in response to SETTINGS frames? cc.henc = hpack.NewEncoder(&cc.hbuf) - type connectionStater interface { - ConnectionState() tls.ConnectionState - } if cs, ok := c.(connectionStater); ok { state := cs.ConnectionState() cc.tlsState = &state