Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: Move some stats handler calls to gRPC layer, and add local address to peer.Peer #6716

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ var (
// gRPC server. An xDS-enabled server needs to know what type of credentials
// is configured on the underlying gRPC server. This is set by server.go.
GetServerCredentials any // func (*grpc.Server) credentials.TransportCredentials
// GetConnection gets the connection from the context.
GetConnection any // func (context.Context) net.Conn
// CanonicalString returns the canonical string of the code defined here:
// https://github.com/grpc/grpc/blob/master/doc/statuscodes.md.
//
Expand Down
36 changes: 22 additions & 14 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
return nil, errors.New(msg)
}

var localAddr net.Addr
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr = la.(net.Addr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will panic if the wrong type is here. localAddr, _ = ... is safer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
var authInfo credentials.AuthInfo
if r.TLS != nil {
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
p := peer.Peer{
Addr: strAddr(r.RemoteAddr),
LocalAddr: localAddr,
AuthInfo: authInfo,
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
peer: p,
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
Expand Down Expand Up @@ -134,6 +148,8 @@ type serverHandlerTransport struct {

headerMD metadata.MD

peer peer.Peer

closeOnce sync.Once
closedCh chan struct{} // closed on Close

Expand Down Expand Up @@ -166,13 +182,10 @@ func (ht *serverHandlerTransport) Close(err error) {
}

func (ht *serverHandlerTransport) Peer() *peer.Peer {
var localAddr net.Addr
if la := ht.req.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr = la.(net.Addr)
}
return &peer.Peer{
Addr: strAddr(ht.req.RemoteAddr),
LocalAddr: localAddr,
Addr: ht.peer.Addr,
LocalAddr: ht.peer.LocalAddr,
AuthInfo: ht.peer.AuthInfo,
}
}

Expand Down Expand Up @@ -356,9 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
ctx := ht.req.Context()
var cancel context.CancelFunc
if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
Expand All @@ -378,9 +390,11 @@ func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream f
ht.Close(errors.New("request is done processing"))
}()

ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req
s := &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
Expand All @@ -390,12 +404,6 @@ func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream f
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
pr := ht.Peer()
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {},
Expand Down
42 changes: 29 additions & 13 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ type http2Server struct {
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
remoteAddr net.Addr
localAddr net.Addr
authInfo credentials.AuthInfo // auth info about the connection
peer peer.Peer
inTapHandle tap.ServerInHandle
framer *framer
// The max number of concurrent streams.
Expand Down Expand Up @@ -242,12 +240,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
}

done := make(chan struct{})
peer := peer.Peer{
Addr: conn.RemoteAddr(),
LocalAddr: conn.LocalAddr(),
AuthInfo: authInfo,
}
t := &http2Server{
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
authInfo: authInfo,
peer: peer,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
Expand All @@ -273,7 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1366,11 +1367,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
// RemoteName :
}
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
Expand Down Expand Up @@ -1405,9 +1406,9 @@ func (t *http2Server) getOutFlowWindow() int64 {
// Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{
Addr: t.remoteAddr,
LocalAddr: t.localAddr,
AuthInfo: t.authInfo, // Can be nil
Addr: t.peer.Addr,
LocalAddr: t.peer.LocalAddr,
AuthInfo: t.peer.AuthInfo, // Can be nil
}
}

Expand All @@ -1420,3 +1421,18 @@ func getJitter(v time.Duration) time.Duration {
j := grpcrand.Int63n(2*r) - r
return time.Duration(j)
}

type connectionKey struct{}

// GetConnection gets the connection from the context.
func GetConnection(ctx context.Context) net.Conn {
conn, _ := ctx.Value(connectionKey{}).(net.Conn)
return conn
}

// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}
7 changes: 4 additions & 3 deletions internal/xds/rbac/rbac_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ import (
"net"
"strconv"

v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
"google.golang.org/grpc"
"google.golang.org/grpc/authz/audit"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"

v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
)

var logger = grpclog.Component("rbac")

var getConnection = internal.GetConnection.(func(ctx context.Context) net.Conn)
var getConnection = transport.GetConnection

// ChainEngine represents a chain of RBAC Engines, used to make authorization
// decisions on incoming RPCs.
Expand Down
25 changes: 5 additions & 20 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
}
internal.BinaryLogger = binaryLogger
internal.JoinServerOptions = newJoinServerOption
internal.GetConnection = getConnection
internal.RecvBufferPool = recvBufferPool
}

var statusOK = status.New(codes.OK, "")
Expand Down Expand Up @@ -920,7 +920,7 @@
return
}
go func() {
s.serveStreams(st, rawConn)
s.serveStreams(context.Background(), st, rawConn)
s.removeConn(lisAddr, st)
}()
}
Expand Down Expand Up @@ -974,23 +974,8 @@
return st
}

type connectionKey struct{}

// getConnection gets the connection from the context.
func getConnection(ctx context.Context) net.Conn {
conn, _ := ctx.Value(connectionKey{}).(net.Conn)
return conn
}

// setConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}

func (s *Server) serveStreams(st transport.ServerTransport, rawConn net.Conn) {
ctx := setConnection(context.Background(), rawConn)
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
ctx = transport.SetConnection(ctx, rawConn)
ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
Expand Down Expand Up @@ -1073,7 +1058,7 @@
return
}
defer s.removeConn(listenerAddressForServeHTTP, st)
s.serveStreams(st, nil)
s.serveStreams(r.Context(), st, nil)
}

func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
Expand Down Expand Up @@ -1730,7 +1715,7 @@
tr: tr,
firstLine: firstLine{
client: false,
remoteAddr: t.Peer().Addr,

Check warning on line 1718 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L1718

Added line #L1718 was not covered by tests
},
}
if dl, ok := ctx.Deadline(); ok {
Expand Down
5 changes: 3 additions & 2 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
Expand Down Expand Up @@ -5828,7 +5829,7 @@ func (s) TestClientSettingsFloodCloseConn(t *testing.T) {
}

func unaryInterceptorVerifyConn(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
conn := internal.GetConnection.(func(context.Context) net.Conn)(ctx)
conn := transport.GetConnection(ctx)
if conn == nil {
return nil, status.Error(codes.NotFound, "connection was not in context")
}
Expand All @@ -5853,7 +5854,7 @@ func (s) TestUnaryServerInterceptorGetsConnection(t *testing.T) {
}

func streamingInterceptorVerifyConn(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
conn := internal.GetConnection.(func(context.Context) net.Conn)(ss.Context())
conn := transport.GetConnection(ss.Context())
if conn == nil {
return status.Error(codes.NotFound, "connection was not in context")
}
Expand Down
3 changes: 2 additions & 1 deletion xds/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcsync"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/xds/internal/server"
Expand Down Expand Up @@ -340,7 +341,7 @@ func (s *GRPCServer) GracefulStop() {
// table and also processes the RPC by running the incoming RPC through any HTTP
// Filters configured.
func routeAndProcess(ctx context.Context) error {
conn := internal.GetConnection.(func(context.Context) net.Conn)(ctx)
conn := transport.GetConnection(ctx)
cw, ok := conn.(interface {
VirtualHosts() []xdsresource.VirtualHostWithInterceptors
})
Expand Down
Loading