Skip to content

Commit

Permalink
Add a ServeHTTP method to *grpc.Server
Browse files Browse the repository at this point in the history
This adds new http.Handler-based ServerTransport in the process,
reusing the HTTP/2 server code in x/net/http2 or Go 1.6+.

All end2end tests pass with this new ServerTransport.

Fixes grpc#75

Also:
Updates grpc#495 (lets user fix it with middleware in front)
Updates grpc#468 (x/net/http2 validates)
Updates grpc#147 (possible with x/net/http2)
Updates grpc#104 (x/net/http2 does this)
  • Loading branch information
bradfitz committed Feb 9, 2016
1 parent 16885aa commit d1f014c
Show file tree
Hide file tree
Showing 7 changed files with 632 additions and 46 deletions.
2 changes: 1 addition & 1 deletion rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
case compressionNone:
case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
Expand Down
172 changes: 141 additions & 31 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ import (
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"time"

"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -82,10 +84,11 @@ type service struct {

// Server is a gRPC server to serve RPC requests.
type Server struct {
opts options
mu sync.Mutex
opts options

mu sync.Mutex // guards following
lis map[net.Listener]bool
conns map[transport.ServerTransport]bool
conns map[io.Closer]bool
m map[string]*service // service name -> service info
events trace.EventLog
}
Expand All @@ -96,6 +99,7 @@ type options struct {
cp Compressor
dc Decompressor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}

// A ServerOption sets options.
Expand Down Expand Up @@ -149,7 +153,7 @@ func NewServer(opt ...ServerOption) *Server {
s := &Server{
lis: make(map[net.Listener]bool),
opts: opts,
conns: make(map[transport.ServerTransport]bool),
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
if EnableTracing {
Expand Down Expand Up @@ -216,9 +220,17 @@ var (
ErrServerStopped = errors.New("grpc: the server has been stopped")
)

func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
if !ok {
return rawConn, nil, nil
}
return creds.ServerHandshake(rawConn)
}

// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC request and then call the registered handlers to reply to them.
// read gRPC requests and then call the registered handlers to reply to them.
// Service returns when lis.Accept fails.
func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock()
Expand All @@ -235,39 +247,54 @@ func (s *Server) Serve(lis net.Listener) error {
delete(s.lis, lis)
s.mu.Unlock()
}()
listenerAddr := lis.Addr()
for {
c, err := lis.Accept()
rawConn, err := lis.Accept()
if err != nil {
s.mu.Lock()
s.printf("done serving; Accept = %v", err)
s.mu.Unlock()
return err
}
var authInfo credentials.AuthInfo
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
var conn net.Conn
conn, authInfo, err = creds.ServerHandshake(c)
if err != nil {
s.mu.Lock()
s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
continue
}
c = conn
}
// Start a new goroutine to deal with rawConn
// so we don't stall this Accept loop goroutine.
go s.handleRawConn(listenerAddr, rawConn)
}
}

// handleRawConn is run in its own goroutine and handles a just-accepted
// connection that has not had any I/O performed on it yet.
func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) {
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
if err != nil {
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
c.Close()
return nil
}
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
rawConn.Close()
return
}

go s.serveNewHTTP2Transport(c, authInfo)
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
conn.Close()
return
}
s.mu.Unlock()

if s.opts.useHandlerImpl {
s.serveUsingHandler(listenerAddr, conn)
} else {
s.serveNewHTTP2Transport(conn, authInfo)
}
}

// serveNewHTTP2Transport sets up a new http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go) and
// serves streams on it.
// This is run in its own goroutine (it does network I/O in
// transport.NewServerTransport).
func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
if err != nil {
Expand Down Expand Up @@ -299,6 +326,52 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
wg.Wait()
}

var _ http.Handler = (*Server)(nil)

// serveUsingHandler is the implementation of Serve(net.Listener) when
// TestingUseHandlerImpl has been configured. This lets the end2end
// tests exercise the ServeHTTP method as one of the environment types.
//
// conn is the *tls.Conn that's already been authenticated.
func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) {
if !s.addConn(conn) {
conn.Close()
return
}
defer s.removeConn(conn)
connDone := make(chan struct{})
hs := &http.Server{
Handler: s,
ConnState: func(c net.Conn, cs http.ConnState) {
if cs == http.StateClosed {
close(connDone)
}
},
}
if err := http2.ConfigureServer(hs, &http2.Server{
MaxConcurrentStreams: s.opts.maxConcurrentStreams,
}); err != nil {
grpclog.Fatalf("grpc: http2.ConfigureServer: %v", err)
return
}
hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn})
<-connDone
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !s.addConn(st) {
st.Close()
return
}
defer s.removeConn(st)
s.serveStreams(st)
}

// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
Expand All @@ -317,21 +390,21 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
return trInfo
}

func (s *Server) addConn(st transport.ServerTransport) bool {
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
return false
}
s.conns[st] = true
s.conns[c] = true
return true
}

func (s *Server) removeConn(st transport.ServerTransport) {
func (s *Server) removeConn(c io.Closer) {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, st)
delete(s.conns, c)
}
}

Expand Down Expand Up @@ -603,12 +676,14 @@ func (s *Server) Stop() {
cs := s.conns
s.conns = nil
s.mu.Unlock()

for lis := range listeners {
lis.Close()
}
for c := range cs {
c.Close()
}

s.mu.Lock()
if s.events != nil {
s.events.Finish()
Expand All @@ -618,16 +693,24 @@ func (s *Server) Stop() {
}

// TestingCloseConns closes all exiting transports but keeps s.lis accepting new
// connections. This is for test only now.
// connections.
// This is only for tests and is subject to removal.
func (s *Server) TestingCloseConns() {
s.mu.Lock()
for c := range s.conns {
c.Close()
delete(s.conns, c)
}
s.conns = make(map[transport.ServerTransport]bool)
s.mu.Unlock()
}

// TestingUseHandlerImpl enables the http.Handler-based server implementation.
// It must be called before Serve and requires TLS credentials.
// This is only for tests and is subject to removal.
func (s *Server) TestingUseHandlerImpl() {
s.opts.useHandlerImpl = true
}

// SendHeader sends header metadata. It may be called at most once from a unary
// RPC handler. The ctx is the RPC handler's Context or one derived from it.
func SendHeader(ctx context.Context, md metadata.MD) error {
Expand Down Expand Up @@ -658,3 +741,30 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
}
return stream.SetTrailer(md)
}

// singleConnListener is a net.Listener that yields a single conn.
type singleConnListener struct {
mu sync.Mutex
addr net.Addr
conn net.Conn // nil if done
}

func (ln *singleConnListener) Addr() net.Addr { return ln.addr }

func (ln *singleConnListener) Close() error {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.conn = nil
return nil
}

func (ln *singleConnListener) Accept() (net.Conn, error) {
ln.mu.Lock()
defer ln.mu.Unlock()
c := ln.conn
if c == nil {
return nil, io.EOF
}
ln.conn = nil
return c, nil
}
Loading

0 comments on commit d1f014c

Please sign in to comment.