Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: Add a pointer of server to ctx passed into stats handler #6750

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ var (
// xDS-enabled server invokes this method on a grpc.Server when a particular
// listener moves to "not-serving" mode.
DrainServerTransports any // func(*grpc.Server, string)
// IsRegisteredMethod returns whether the passed in method is registered as
// a method on the server.
IsRegisteredMethod any // func(*grpc.Server, string)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a bool right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right, whoops. Added.

// ServerFromContext returns the server from the context.
ServerFromContext any // func(context.Context) *grpc.Server
// AddGlobalServerOptions adds an array of ServerOption that will be
// effective globally for newly created servers. The priority will be: 1.
// user-provided; 2. this method; 3. default values.
Expand Down
72 changes: 72 additions & 0 deletions internal/stubstatshandler/stubstatshandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package stubstatshandler is a stubbable implementation of
// google.golang.org/grpc/stats.Handler for testing purposes.
package stubstatshandler

import (
"context"

"google.golang.org/grpc/stats"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
)

// StubStatsHandler is a stats handler that is easy to customize within
// individual test cases.
type StubStatsHandler struct {
// Guarantees we satisfy this interface; panics if unimplemented methods are
// called.
testgrpc.TestServiceServer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you bundling a stats handler with a service implementation? That seems undesirable unless it's somehow important to keep the two things together. (And now that I've look at the rest of the code, it seems unused and unnecessary.)

Also, we should put this under testutils (either in that package or a subdirectory of it).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, whoops, this is leftover from the stubserver copy paste. Changed to embedding a stats.Handler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved file to within testutils as internal/testutils/stubstatshandler.go.


TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context
HandleRPCF func(ctx context.Context, info stats.RPCStats)
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context
HandleConnF func(ctx context.Context, info stats.ConnStats)
}

// TagRPC calls the StubStatsHandler's TagRPCF, if set.
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
if ssh.TagRPCF != nil {
return ssh.TagRPCF(ctx, info)
}
return ctx
}

// HandleRPC calls the StubStatsHandler's HandleRPCF, if set.
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
if ssh.HandleRPCF != nil {
ssh.HandleRPCF(ctx, rs)
}
}

// TagConn calls the StubStatsHandler's TagConnF, if set.
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
if ssh.TagConnF != nil {
return ssh.TagConnF(ctx, info)
}
return ctx
}

// HandleConn calls the StubStatsHandler's HandleConnF, if set.
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
if ssh.HandleConnF != nil {
ssh.HandleConnF(ctx, cs)
}
}
43 changes: 43 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ func init() {
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
return srv.opts.creds
}
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
return srv.isRegisteredMethod(method)
}
internal.ServerFromContext = serverFromContext
internal.DrainServerTransports = func(srv *Server, addr string) {
srv.drainServerTransports(addr)
}
Expand Down Expand Up @@ -1707,6 +1711,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
ctx := stream.Context()
ctx = contextWithServer(ctx, s)
var ti *traceInfo
if EnableTracing {
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
Expand Down Expand Up @@ -1953,6 +1958,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
return codec
}

type serverKey struct{}

// serverFromContext gets the Server from the context.
func serverFromContext(ctx context.Context) *Server {
s, _ := ctx.Value(serverKey{}).(*Server)
return s
}

// contextWithServer sets the Server in the context.
func contextWithServer(ctx context.Context, server *Server) context.Context {
return context.WithValue(ctx, serverKey{}, server)
}

// isRegisteredMethod returns whether the passed in method is registered as a
// method on the server. /service/method and service/method will match if the
// service and method are registered on the server.
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this very similar to the code that looks up the handler for a method? Can we reuse code somehow, or even just call it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is very similar, but I chose not to do it because this line is very specific: https://github.com/grpc/grpc-go/blob/master/server.go#L1731, and I thought it would be hard to generalize what's shared into a helper. If you feel strongly about this after reading the handleStream() code, I can try and pull out logic into helper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah please see what you can do. I think a method that returns (service string, method string, ok bool) would be good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried writing it, and it seems distinct enough logically to not warrant a helper. Your function signature doesn't distinguish between a. malformed method name (important here for tracing and logging: https://github.com/grpc/grpc-go/blob/master/server.go#L1732) b. registered unary rpc c. registered streaming name (important here: https://github.com/grpc/grpc-go/blob/master/server.go#L1770, to chose to go through processUnary or processStreaming). This helper's scope is not to log anything in traces or channelz, but to parse a :method header received from the wire and determine whether it's registered or not registered. I think it's sufficiently different to not warrant a new helper.

if serviceMethod != "" && serviceMethod[0] == '/' {
serviceMethod = serviceMethod[1:]
}
pos := strings.LastIndex(serviceMethod, "/")
if pos == -1 { // Invalid method name syntax.
return false
}
service := serviceMethod[:pos]
method := serviceMethod[pos+1:]
srv, knownService := s.services[service]
if knownService {
if _, ok := srv.methods[method]; ok {
return true
}
if _, ok := srv.streams[method]; ok {
return true
}
}
return false
}

// SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler.
//
Expand Down
79 changes: 79 additions & 0 deletions stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ package stats_test

import (
"context"
"errors"
"fmt"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/stubstatshandler"
"google.golang.org/grpc/internal/testutils"
"io"
"net"
"reflect"
Expand Down Expand Up @@ -1457,3 +1462,77 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) {
t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
}
}

// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
// gets access to a Server on the server side, and thus the method that the
// server owns which specifies whether a method is made or not. The test sets up
// a server with a unary call and full duplex call configured, and makes an RPC.
// Within the stats handler, asking the server whether unary or duplex method
// names are registered should return true, and any other query should return
// false.
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
errorCh := testutils.NewChannel()
stubStatsHandler := &stubstatshandler.StubStatsHandler{
TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
// OpenTelemetry instrumentation needs the passed in Server to determine if
// methods are registered in different handle calls in to record metrics.
// This tag RPC call context gets passed into every handle call, so can
// assert once here, since it maps to all the handle RPC calls that come
// after. These internal calls will be how the OpenTelemetry instrumentation
// component accesses this server and the subsequent helper on the server.
server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
if server == nil {
errorCh.Send("stats handler received ctx has no server present")
}
isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
// /s/m and s/m are valid.
if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
errorCh.Send(errors.New("UnaryCall should be a registered method according to server"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't actually need the error channel for this. You can just do t.Errorf here. Then all you need to do is ensure the code actually was called. You can do this by closing a done, channel or with a waitgroup, etc. One benefit of this is that you'll see all the test case failures in a single run if multiple of these checks don't work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, this was when it was a defined stats handler rather than a stub with methods declared in test. Thus, before I couldn't close on t *testing.T, and now I can :). Switched.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chose a waitGroup.

return ctx
}
if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
errorCh.Send(errors.New("FullDuplexCall should be a registered method according to server"))
return ctx
}
if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
errorCh.Send(errors.New("DoesNotExistCall should not be a registered method according to server"))
return ctx
}
if isRegisteredMethod(server, "/unknownService/UnaryCall") {
errorCh.Send(errors.New("/unknownService/UnaryCall should not be a registered method according to server"))
return ctx
}
errorCh.Send(nil)
return ctx
},
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You never call this; either delete it or test it please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do test this. See line 1493. However, you are right that I never initiate a bidirectional streaming RPC from the test. Would you like me to add that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This closure being set or not doesn't change the test at line 1494. The method is registered because it's in the service descriptor. I'm fine with deleting this; I don't think there's any need to test both streaming and unary RPCs unless there are different code paths that lead to setting the server in the context, which there aren't.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right, these closures != service descriptor, it's all there anyway. Deleting.

for {
if _, err := stream.Recv(); err == io.EOF {
return nil
}
}
},
}
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
t.Fatalf("Unexpected error from UnaryCall: %v", err)
}
err, errRecv := errorCh.Receive(ctx)
if errRecv != nil {
t.Fatalf("error receiving from channel: %v", errRecv)
}
if err != nil {
t.Fatalf("error received from error channel: %v", err)
}
}