Skip to content

Commit

Permalink
Responded to offline discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq committed Oct 23, 2023
1 parent 1fd9e20 commit 8a70332
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 29 deletions.
33 changes: 20 additions & 13 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
return nil, errors.New(msg)
}

var localAddr net.Addr
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr = la.(net.Addr)
}
var authInfo credentials.AuthInfo
if r.TLS != nil {
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
p := peer.Peer{
Addr: strAddr(r.RemoteAddr),
LocalAddr: localAddr,
AuthInfo: authInfo,
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
peer: p,
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
Expand Down Expand Up @@ -134,6 +148,8 @@ type serverHandlerTransport struct {

headerMD metadata.MD

peer peer.Peer

closeOnce sync.Once
closedCh chan struct{} // closed on Close

Expand Down Expand Up @@ -166,13 +182,10 @@ func (ht *serverHandlerTransport) Close(err error) {
}

func (ht *serverHandlerTransport) Peer() *peer.Peer {
var localAddr net.Addr
if la := ht.req.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr = la.(net.Addr)
}
return &peer.Peer{
Addr: strAddr(ht.req.RemoteAddr),
LocalAddr: localAddr,
Addr: ht.peer.Addr,
LocalAddr: ht.peer.LocalAddr,
AuthInfo: ht.peer.AuthInfo,
}
}

Expand Down Expand Up @@ -356,9 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
ctx := ht.req.Context()
var cancel context.CancelFunc
if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
Expand Down Expand Up @@ -390,12 +402,7 @@ func (ht *serverHandlerTransport) HandleStreams(_ context.Context, startStream f
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
pr := ht.Peer()
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {},
Expand Down
27 changes: 15 additions & 12 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ type http2Server struct {
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
remoteAddr net.Addr
localAddr net.Addr
authInfo credentials.AuthInfo // auth info about the connection
peer peer.Peer
inTapHandle tap.ServerInHandle
framer *framer
// The max number of concurrent streams.
Expand Down Expand Up @@ -242,12 +241,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
}

done := make(chan struct{})
peer := peer.Peer{
Addr: conn.RemoteAddr(),
LocalAddr: conn.LocalAddr(),
AuthInfo: authInfo,
}
t := &http2Server{
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
authInfo: authInfo,
peer: peer,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
Expand All @@ -273,7 +275,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1292,6 +1294,7 @@ func (t *http2Server) Drain(debugData string) {
}
t.drainEvent = grpcsync.NewEvent()
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte(debugData), headsUp: true})
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte(debugData), headsUp: true})
}

var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
Expand Down Expand Up @@ -1366,11 +1369,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
// RemoteName :
}
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
Expand Down Expand Up @@ -1405,9 +1408,9 @@ func (t *http2Server) getOutFlowWindow() int64 {
// Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{
Addr: t.remoteAddr,
LocalAddr: t.localAddr,
AuthInfo: t.authInfo, // Can be nil
Addr: t.peer.Addr,
LocalAddr: t.peer.LocalAddr,
AuthInfo: t.peer.AuthInfo,
}
}

Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return
}
go func() {
s.serveStreams(st, rawConn)
s.serveStreams(context.Background(), st, rawConn)
s.removeConn(lisAddr, st)
}()
}
Expand Down Expand Up @@ -987,8 +987,8 @@ func setConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}

func (s *Server) serveStreams(st transport.ServerTransport, rawConn net.Conn) {
ctx := setConnection(context.Background(), rawConn)
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
ctx = setConnection(ctx, rawConn)
ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
Expand Down Expand Up @@ -1071,7 +1071,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
defer s.removeConn(listenerAddressForServeHTTP, st)
s.serveStreams(st, nil)
s.serveStreams(r.Context(), st, nil)
}

func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
Expand Down

0 comments on commit 8a70332

Please sign in to comment.