diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index eed0b360b840..0ae321ffde1d 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s return nil, errors.New(msg) } + var localAddr net.Addr + if la := r.Context().Value(http.LocalAddrContextKey); la != nil { + localAddr = la.(net.Addr) + } + var authInfo credentials.AuthInfo + if r.TLS != nil { + authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}} + } + p := peer.Peer{ + Addr: strAddr(r.RemoteAddr), + LocalAddr: localAddr, + AuthInfo: authInfo, + } st := &serverHandlerTransport{ rw: w, req: r, closedCh: make(chan struct{}), writes: make(chan func()), + peer: p, contentType: contentType, contentSubtype: contentSubtype, stats: stats, @@ -134,6 +148,8 @@ type serverHandlerTransport struct { headerMD metadata.MD + peer peer.Peer + closeOnce sync.Once closedCh chan struct{} // closed on Close @@ -166,13 +182,10 @@ func (ht *serverHandlerTransport) Close(err error) { } func (ht *serverHandlerTransport) Peer() *peer.Peer { - var localAddr net.Addr - if la := ht.req.Context().Value(http.LocalAddrContextKey); la != nil { - localAddr = la.(net.Addr) - } return &peer.Peer{ - Addr: strAddr(ht.req.RemoteAddr), - LocalAddr: localAddr, + Addr: ht.peer.Addr, + LocalAddr: ht.peer.LocalAddr, + AuthInfo: ht.peer.AuthInfo, } } @@ -356,9 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { return err } -func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream func(*Stream)) { +func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) { // With this transport type there will be exactly 1 stream: this HTTP request. - ctx := ht.req.Context() var cancel context.CancelFunc if ht.timeoutSet { ctx, cancel = context.WithTimeout(ctx, ht.timeout) @@ -390,12 +402,7 @@ func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream f contentSubtype: ht.contentSubtype, headerWireLength: 0, // won't have access to header wire length until golang/go#18997. } - pr := ht.Peer() - if req.TLS != nil { - pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}} - } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) - s.ctx = peer.NewContext(ctx, pr) s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, windowHandler: func(int) {}, diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 68af6f297459..922c3e0bd92d 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -75,8 +75,7 @@ type http2Server struct { readerDone chan struct{} // sync point to enable testing. writerDone chan struct{} // sync point to enable testing. remoteAddr net.Addr - localAddr net.Addr - authInfo credentials.AuthInfo // auth info about the connection + peer peer.Peer inTapHandle tap.ServerInHandle framer *framer // The max number of concurrent streams. @@ -242,12 +241,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, } done := make(chan struct{}) + peer := peer.Peer{ + Addr: conn.RemoteAddr(), + LocalAddr: conn.LocalAddr(), + AuthInfo: authInfo, + } t := &http2Server{ done: done, conn: conn, - remoteAddr: conn.RemoteAddr(), - localAddr: conn.LocalAddr(), - authInfo: authInfo, + peer: peer, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), @@ -273,7 +275,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, updateFlowControl: t.updateFlowControl, } } - t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) + t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr)) if err != nil { return nil, err } @@ -1292,6 +1294,7 @@ func (t *http2Server) Drain(debugData string) { } t.drainEvent = grpcsync.NewEvent() t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte(debugData), headsUp: true}) + t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte(debugData), headsUp: true}) } var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} @@ -1366,11 +1369,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), LocalFlowControlWindow: int64(t.fc.getSize()), SocketOptions: channelz.GetSocketOption(t.conn), - LocalAddr: t.localAddr, - RemoteAddr: t.remoteAddr, + LocalAddr: t.peer.LocalAddr, + RemoteAddr: t.peer.Addr, // RemoteName : } - if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { + if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok { s.Security = au.GetSecurityValue() } s.RemoteFlowControlWindow = t.getOutFlowWindow() @@ -1405,9 +1408,9 @@ func (t *http2Server) getOutFlowWindow() int64 { // Peer returns the peer of the transport. func (t *http2Server) Peer() *peer.Peer { return &peer.Peer{ - Addr: t.remoteAddr, - LocalAddr: t.localAddr, - AuthInfo: t.authInfo, // Can be nil + Addr: t.peer.Addr, + LocalAddr: t.peer.LocalAddr, + AuthInfo: t.peer.AuthInfo, } } diff --git a/server.go b/server.go index 0b9447719e19..d79fbc4ca483 100644 --- a/server.go +++ b/server.go @@ -918,7 +918,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { return } go func() { - s.serveStreams(st, rawConn) + s.serveStreams(context.Background(), st, rawConn) s.removeConn(lisAddr, st) }() } @@ -987,8 +987,8 @@ func setConnection(ctx context.Context, conn net.Conn) context.Context { return context.WithValue(ctx, connectionKey{}, conn) } -func (s *Server) serveStreams(st transport.ServerTransport, rawConn net.Conn) { - ctx := setConnection(context.Background(), rawConn) +func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) { + ctx = setConnection(ctx, rawConn) ctx = peer.NewContext(ctx, st.Peer()) for _, sh := range s.opts.statsHandlers { ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ @@ -1071,7 +1071,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } defer s.removeConn(listenerAddressForServeHTTP, st) - s.serveStreams(st, nil) + s.serveStreams(r.Context(), st, nil) } func (s *Server) addConn(addr string, st transport.ServerTransport) bool {