-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
grpc: Add a pointer of server to ctx passed into stats handler #6750
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* | ||
* Copyright 2023 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
// Package stubstatshandler is a stubbable implementation of | ||
// google.golang.org/grpc/stats.Handler for testing purposes. | ||
package stubstatshandler | ||
|
||
import ( | ||
"context" | ||
|
||
"google.golang.org/grpc/stats" | ||
|
||
testgrpc "google.golang.org/grpc/interop/grpc_testing" | ||
) | ||
|
||
// StubStatsHandler is a stats handler that is easy to customize within | ||
// individual test cases. | ||
type StubStatsHandler struct { | ||
// Guarantees we satisfy this interface; panics if unimplemented methods are | ||
// called. | ||
testgrpc.TestServiceServer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you bundling a stats handler with a service implementation? That seems undesirable unless it's somehow important to keep the two things together. (And now that I've look at the rest of the code, it seems unused and unnecessary.) Also, we should put this under There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, whoops, this is leftover from the stubserver copy paste. Changed to embedding a stats.Handler. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved file to within testutils as internal/testutils/stubstatshandler.go. |
||
|
||
TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context | ||
HandleRPCF func(ctx context.Context, info stats.RPCStats) | ||
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context | ||
HandleConnF func(ctx context.Context, info stats.ConnStats) | ||
} | ||
|
||
// TagRPC calls the StubStatsHandler's TagRPCF, if set. | ||
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { | ||
if ssh.TagRPCF != nil { | ||
return ssh.TagRPCF(ctx, info) | ||
} | ||
return ctx | ||
} | ||
|
||
// HandleRPC calls the StubStatsHandler's HandleRPCF, if set. | ||
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { | ||
if ssh.HandleRPCF != nil { | ||
ssh.HandleRPCF(ctx, rs) | ||
} | ||
} | ||
|
||
// TagConn calls the StubStatsHandler's TagConnF, if set. | ||
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { | ||
if ssh.TagConnF != nil { | ||
return ssh.TagConnF(ctx, info) | ||
} | ||
return ctx | ||
} | ||
|
||
// HandleConn calls the StubStatsHandler's HandleConnF, if set. | ||
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) { | ||
if ssh.HandleConnF != nil { | ||
ssh.HandleConnF(ctx, cs) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,10 @@ func init() { | |
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { | ||
return srv.opts.creds | ||
} | ||
internal.IsRegisteredMethod = func(srv *Server, method string) bool { | ||
return srv.isRegisteredMethod(method) | ||
} | ||
internal.ServerFromContext = serverFromContext | ||
internal.DrainServerTransports = func(srv *Server, addr string) { | ||
srv.drainServerTransports(addr) | ||
} | ||
|
@@ -1707,6 +1711,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran | |
|
||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { | ||
ctx := stream.Context() | ||
ctx = contextWithServer(ctx, s) | ||
var ti *traceInfo | ||
if EnableTracing { | ||
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) | ||
|
@@ -1953,6 +1958,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec { | |
return codec | ||
} | ||
|
||
type serverKey struct{} | ||
|
||
// serverFromContext gets the Server from the context. | ||
func serverFromContext(ctx context.Context) *Server { | ||
s, _ := ctx.Value(serverKey{}).(*Server) | ||
return s | ||
} | ||
|
||
// contextWithServer sets the Server in the context. | ||
func contextWithServer(ctx context.Context, server *Server) context.Context { | ||
return context.WithValue(ctx, serverKey{}, server) | ||
} | ||
|
||
// isRegisteredMethod returns whether the passed in method is registered as a | ||
// method on the server. /service/method and service/method will match if the | ||
// service and method are registered on the server. | ||
func (s *Server) isRegisteredMethod(serviceMethod string) bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this very similar to the code that looks up the handler for a method? Can we reuse code somehow, or even just call it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it is very similar, but I chose not to do it because this line is very specific: https://github.com/grpc/grpc-go/blob/master/server.go#L1731, and I thought it would be hard to generalize what's shared into a helper. If you feel strongly about this after reading the handleStream() code, I can try and pull out logic into helper. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah please see what you can do. I think a method that returns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried writing it, and it seems distinct enough logically to not warrant a helper. Your function signature doesn't distinguish between a. malformed method name (important here for tracing and logging: https://github.com/grpc/grpc-go/blob/master/server.go#L1732) b. registered unary rpc c. registered streaming name (important here: https://github.com/grpc/grpc-go/blob/master/server.go#L1770, to chose to go through processUnary or processStreaming). This helper's scope is not to log anything in traces or channelz, but to parse a :method header received from the wire and determine whether it's registered or not registered. I think it's sufficiently different to not warrant a new helper. |
||
if serviceMethod != "" && serviceMethod[0] == '/' { | ||
serviceMethod = serviceMethod[1:] | ||
} | ||
pos := strings.LastIndex(serviceMethod, "/") | ||
if pos == -1 { // Invalid method name syntax. | ||
return false | ||
} | ||
service := serviceMethod[:pos] | ||
method := serviceMethod[pos+1:] | ||
srv, knownService := s.services[service] | ||
if knownService { | ||
if _, ok := srv.methods[method]; ok { | ||
return true | ||
} | ||
if _, ok := srv.streams[method]; ok { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
// SetHeader sets the header metadata to be sent from the server to the client. | ||
// The context provided must be the context passed to the server's handler. | ||
// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,12 @@ package stats_test | |
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"google.golang.org/grpc/internal" | ||
"google.golang.org/grpc/internal/stubserver" | ||
"google.golang.org/grpc/internal/stubstatshandler" | ||
"google.golang.org/grpc/internal/testutils" | ||
"io" | ||
"net" | ||
"reflect" | ||
|
@@ -1457,3 +1462,77 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) { | |
t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) | ||
} | ||
} | ||
|
||
// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler | ||
// gets access to a Server on the server side, and thus the method that the | ||
// server owns which specifies whether a method is made or not. The test sets up | ||
// a server with a unary call and full duplex call configured, and makes an RPC. | ||
// Within the stats handler, asking the server whether unary or duplex method | ||
// names are registered should return true, and any other query should return | ||
// false. | ||
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { | ||
errorCh := testutils.NewChannel() | ||
stubStatsHandler := &stubstatshandler.StubStatsHandler{ | ||
TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context { | ||
// OpenTelemetry instrumentation needs the passed in Server to determine if | ||
// methods are registered in different handle calls in to record metrics. | ||
// This tag RPC call context gets passed into every handle call, so can | ||
// assert once here, since it maps to all the handle RPC calls that come | ||
// after. These internal calls will be how the OpenTelemetry instrumentation | ||
// component accesses this server and the subsequent helper on the server. | ||
server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx) | ||
if server == nil { | ||
errorCh.Send("stats handler received ctx has no server present") | ||
} | ||
isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool) | ||
// /s/m and s/m are valid. | ||
if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") { | ||
errorCh.Send(errors.New("UnaryCall should be a registered method according to server")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't actually need the error channel for this. You can just do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, right, this was when it was a defined stats handler rather than a stub with methods declared in test. Thus, before I couldn't close on t *testing.T, and now I can :). Switched. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Chose a waitGroup. |
||
return ctx | ||
} | ||
if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") { | ||
errorCh.Send(errors.New("FullDuplexCall should be a registered method according to server")) | ||
return ctx | ||
} | ||
if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") { | ||
errorCh.Send(errors.New("DoesNotExistCall should not be a registered method according to server")) | ||
return ctx | ||
} | ||
if isRegisteredMethod(server, "/unknownService/UnaryCall") { | ||
errorCh.Send(errors.New("/unknownService/UnaryCall should not be a registered method according to server")) | ||
return ctx | ||
} | ||
errorCh.Send(nil) | ||
return ctx | ||
}, | ||
} | ||
ss := &stubserver.StubServer{ | ||
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { | ||
return &testpb.SimpleResponse{}, nil | ||
}, | ||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You never call this; either delete it or test it please. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do test this. See line 1493. However, you are right that I never initiate a bidirectional streaming RPC from the test. Would you like me to add that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This closure being set or not doesn't change the test at line 1494. The method is registered because it's in the service descriptor. I'm fine with deleting this; I don't think there's any need to test both streaming and unary RPCs unless there are different code paths that lead to setting the server in the context, which there aren't. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, right, these closures != service descriptor, it's all there anyway. Deleting. |
||
for { | ||
if _, err := stream.Recv(); err == io.EOF { | ||
return nil | ||
} | ||
} | ||
}, | ||
} | ||
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil { | ||
t.Fatalf("Error starting endpoint server: %v", err) | ||
} | ||
defer ss.Stop() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
defer cancel() | ||
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil { | ||
t.Fatalf("Unexpected error from UnaryCall: %v", err) | ||
} | ||
err, errRecv := errorCh.Receive(ctx) | ||
if errRecv != nil { | ||
t.Fatalf("error receiving from channel: %v", errRecv) | ||
} | ||
if err != nil { | ||
t.Fatalf("error received from error channel: %v", err) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This returns a
bool
right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, right, whoops. Added.