From 4313fb41de87e262b1ca8986eac4640d5484ff66 Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Wed, 16 Feb 2022 13:07:32 -0500 Subject: [PATCH] improve server reflection 1. Support alternate source of descriptors, like for RPC servers that get their descriptors dynamically and are dynamic proxies 2. Use the new protobuf API v2 stuff to get the descriptors, which is much more sane than the old APIs --- reflection/serverreflection.go | 340 +++++++++------------------- reflection/serverreflection_test.go | 35 +-- 2 files changed, 128 insertions(+), 247 deletions(-) diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 9b387dddee58..850954ad77d6 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -37,16 +37,12 @@ To register server reflection on a gRPC server: package reflection // import "google.golang.org/grpc/reflection" import ( - "bytes" - "compress/gzip" - "fmt" + "errors" "io" - "io/ioutil" "sort" "sync" "github.com/golang/protobuf/proto" - dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/grpc" "google.golang.org/grpc/codes" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" @@ -54,7 +50,6 @@ import ( "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/types/dynamicpb" ) // GRPCServer is the interface provided by a gRPC server. It is implemented by @@ -65,291 +60,165 @@ type GRPCServer interface { GetServiceInfo() map[string]grpc.ServiceInfo } +// ExtensionResolver is the interface used to query details about extensions. +// This interface is satisfied by protoregistry.GlobalTypes. +type ExtensionResolver interface { + protoregistry.ExtensionTypeResolver + RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) +} + var _ GRPCServer = (*grpc.Server)(nil) type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer - s GRPCServer + s GRPCServer + descResolver protodesc.Resolver + extResolver ExtensionResolver initSymbols sync.Once serviceNames []string - symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files -} - -// Register registers the server reflection service on the given gRPC server. -func Register(s GRPCServer) { - rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ - s: s, - }) } -func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { - s.initSymbols.Do(func() { - serviceInfo := s.s.GetServiceInfo() - - s.symbols = map[string]*dpb.FileDescriptorProto{} - s.serviceNames = make([]string, 0, len(serviceInfo)) - processed := map[string]struct{}{} - for svc, info := range serviceInfo { - s.serviceNames = append(s.serviceNames, svc) - fdenc, ok := parseMetadata(info.Metadata) - if !ok { - continue - } - fd, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fd, processed) - } - sort.Strings(s.serviceNames) - }) - - return s.serviceNames, s.symbols -} - -func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { - filename := fd.GetName() - if _, ok := processed[filename]; ok { - return - } - processed[filename] = struct{}{} - - prefix := fd.GetPackage() - - for _, msg := range fd.MessageType { - s.processMessage(fd, prefix, msg) - } - for _, en := range fd.EnumType { - s.processEnum(fd, prefix, en) - } - for _, ext := range fd.Extension { - s.processField(fd, prefix, ext) - } - for _, svc := range fd.Service { - svcName := fqn(prefix, svc.GetName()) - s.symbols[svcName] = fd - for _, meth := range svc.Method { - name := fqn(svcName, meth.GetName()) - s.symbols[name] = fd - } - } - - for _, dep := range fd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fdDep, processed) - } +// ServerOptions represents the options used to construct a reflection server. +// +// Either Server or ServiceNames must be populated, but not both. These control +// what services are advertised by the server in the ListServices capability of +// the reflection service. If neither is provided, the returned server can still +// server descriptors, but it will advertise no service names. +// +// The given DescriptorResolver will be used to resolve symbols and files by +// name. If not present, protoregistry.GlobalFiles will be used. The given +// ExtensionResolver will be used to resolve extensions. If not present, +// protoregistry.GlobalTypes will be used. +type ServerOptions struct { + // An RPC server, whose exposed services are made available via service + // reflection. + Server GRPCServer + // The list of service names. This should only be populated if Server is + // nil. + ServiceNames []string + // Optional resolver used to load descriptors. + DescriptorResolver protodesc.Resolver + // Optional resolver used to query for known extensions. + ExtensionResolver ExtensionResolver } -func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { - msgName := fqn(prefix, msg.GetName()) - s.symbols[msgName] = fd - - for _, nested := range msg.NestedType { - s.processMessage(fd, msgName, nested) - } - for _, en := range msg.EnumType { - s.processEnum(fd, msgName, en) +// NewServer returns a reflection server implementation using the given options. +// It returns an error if the given options are invalid. +func NewServer(opts ServerOptions) (rpb.ServerReflectionServer, error) { + if opts.Server != nil && len(opts.ServiceNames) > 0 { + return nil, errors.New("options must specify either Server or ServiceNames, not both") } - for _, ext := range msg.Extension { - s.processField(fd, msgName, ext) + if opts.DescriptorResolver == nil { + opts.DescriptorResolver = protoregistry.GlobalFiles } - for _, fld := range msg.Field { - s.processField(fd, msgName, fld) - } - for _, oneof := range msg.OneofDecl { - oneofName := fqn(msgName, oneof.GetName()) - s.symbols[oneofName] = fd + if opts.ExtensionResolver == nil { + opts.ExtensionResolver = protoregistry.GlobalTypes } + return &serverReflectionServer{ + s: opts.Server, + descResolver: opts.DescriptorResolver, + extResolver: opts.ExtensionResolver, + serviceNames: opts.ServiceNames, + }, nil } -func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { - enName := fqn(prefix, en.GetName()) - s.symbols[enName] = fd - - for _, val := range en.Value { - valName := fqn(enName, val.GetName()) - s.symbols[valName] = fd - } -} - -func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { - fldName := fqn(prefix, fld.GetName()) - s.symbols[fldName] = fd -} - -func fqn(prefix, name string) string { - if prefix == "" { - return name - } - return prefix + "." + name -} - -// decodeFileDesc does decompression and unmarshalling on the given -// file descriptor byte slice. -func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { - raw, err := decompress(enc) - if err != nil { - return nil, fmt.Errorf("failed to decompress enc: %v", err) - } - - fd := new(dpb.FileDescriptorProto) - if err := proto.Unmarshal(raw, fd); err != nil { - return nil, fmt.Errorf("bad descriptor: %v", err) - } - return fd, nil -} - -// decompress does gzip decompression. -func decompress(b []byte) ([]byte, error) { - r, err := gzip.NewReader(bytes.NewReader(b)) - if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) - } - out, err := ioutil.ReadAll(r) +// Register registers the server reflection service on the given gRPC server. +func Register(s GRPCServer) { + svr, err := NewServer(ServerOptions{Server: s}) if err != nil { - return nil, fmt.Errorf("bad gzipped descriptor: %v", err) + panic(err) // should not be possible } - return out, nil + rpb.RegisterServerReflectionServer(s, svr) } -func fileDescContainingExtension(typeName string, ext int32) (*dpb.FileDescriptorProto, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(typeName)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - var extDesc *proto.ExtensionDesc - for id, desc := range proto.RegisteredExtensions(m) { - if id == ext { - extDesc = desc - break +func (s *serverReflectionServer) init() { + s.initSymbols.Do(func() { + if s.s == nil { + // no need to init; service names were specified at construction + return } - } - - if extDesc == nil { - return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) - } - - return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) + serviceInfo := s.s.GetServiceInfo() + s.serviceNames = make([]string, 0, len(serviceInfo)) + for svc := range serviceInfo { + s.serviceNames = append(s.serviceNames, svc) + } + sort.Strings(s.serviceNames) + }) } // fileDescWithDependencies returns a slice of serialized fileDescriptors in // wire format ([]byte). The fileDescriptors will include fd and all the // transitive dependencies of fd with names not in sentFileDescriptors. -func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { - r := [][]byte{} - queue := []*dpb.FileDescriptorProto{fd} +func (s *serverReflectionServer) fileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + var r [][]byte + queue := []protoreflect.FileDescriptor{fd} for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] - if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { - sentFileDescriptors[currentfd.GetName()] = true - currentfdEncoded, err := proto.Marshal(currentfd) + if _, sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent { + sentFileDescriptors[currentfd.Path()] = struct{}{} + fdProto := protodesc.ToFileDescriptorProto(currentfd) + currentfdEncoded, err := proto.Marshal(fdProto) if err != nil { return nil, err } r = append(r, currentfdEncoded) } - for _, dep := range currentfd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - queue = append(queue, fdDep) + for i := 0; i < currentfd.Imports().Len(); i++ { + queue = append(queue, currentfd.Imports().Get(i)) } } return r, nil } -// fileDescEncodingByFilename finds the file descriptor for given filename, -// finds all of its previously unsent transitive dependencies, does marshalling -// on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - enc := proto.FileDescriptor(name) - if enc == nil { - return nil, fmt.Errorf("unknown file: %v", name) - } - fd, err := decodeFileDesc(enc) - if err != nil { - return nil, err - } - return fileDescWithDependencies(fd, sentFileDescriptors) -} - -// parseMetadata finds the file descriptor bytes specified meta. -// For SupportPackageIsVersion4, m is the name of the proto file, we -// call proto.FileDescriptor to get the byte slice. -// For SupportPackageIsVersion3, m is a byte slice itself. -func parseMetadata(meta interface{}) ([]byte, bool) { - // Check if meta is the file name. - if fileNameForMeta, ok := meta.(string); ok { - return proto.FileDescriptor(fileNameForMeta), true - } - - // Check if meta is the byte slice. - if enc, ok := meta.([]byte); ok { - return enc, true - } - - return nil, false -} - // fileDescEncodingContainingSymbol finds the file descriptor containing the // given symbol, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. The given symbol // can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { - _, symbols := s.getSymbols() - fd := symbols[name] - if fd == nil { - // Check if it's a type name that was not present in the - // transitive dependencies of the registered services. - desc, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - fd = protodesc.ToFileDescriptorProto(desc.Descriptor().ParentFile()) +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + d, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)) + if err != nil { + return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(d.ParentFile(), sentFileDescriptors) } // fileDescEncodingContainingExtension finds the file descriptor containing // given extension, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { - fd, err := fileDescContainingExtension(typeName, extNum) +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + xt, err := s.extResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum)) if err != nil { return nil, err } - return fileDescWithDependencies(fd, sentFileDescriptors) + return s.fileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { - desc, err := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(name)) - if err != nil { - return nil, err - } - m := dynamicpb.NewMessage(desc.(protoreflect.MessageDescriptor)) - - exts := proto.RegisteredExtensions(m) - extNums := make([]int32, 0, len(exts)) - for id := range exts { - extNums = append(extNums, id) + var numbers []int32 + s.extResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool { + numbers = append(numbers, int32(xt.TypeDescriptor().Number())) + return true + }) + sort.Slice(numbers, func(i, j int) bool { + return numbers[i] < numbers[j] + }) + if len(numbers) == 0 { + // maybe return an error if given type name is not known + if _, err := s.descResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil { + return nil, err + } } - return extNums, nil + return numbers, nil } // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { - sentFileDescriptors := make(map[string]bool) + s.init() + + sentFileDescriptors := make(map[string]struct{}) for { in, err := stream.Recv() if err == io.EOF { @@ -365,7 +234,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) + var b [][]byte + fd, err := s.descResolver.FindFileByPath(req.FileByFilename) + if err == nil { + b, err = s.fileDescWithDependencies(fd, sentFileDescriptors) + } if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -426,9 +299,8 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_ListServices: - svcNames, _ := s.getSymbols() - serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) - for i, n := range svcNames { + serviceResponses := make([]*rpb.ServiceResponse, len(s.serviceNames)) + for i, n := range s.serviceNames { serviceResponses[i] = &rpb.ServiceResponse{ Name: n, } diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index 730f51bf012d..ff1369fc1015 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -43,7 +43,10 @@ import ( ) var ( - s = &serverReflectionServer{} + s = &serverReflectionServer{ + descResolver: protoregistry.GlobalFiles, + extResolver: protoregistry.GlobalTypes, + } // fileDescriptor of each test proto file. fdTest *dpb.FileDescriptorProto fdTestv3 *dpb.FileDescriptorProto @@ -73,19 +76,16 @@ func Test(t *testing.T) { } func loadFileDesc(filename string) (*dpb.FileDescriptorProto, []byte) { - enc := proto.FileDescriptor(filename) - if enc == nil { - panic(fmt.Sprintf("failed to find fd for file: %v", filename)) - } - fd, err := decodeFileDesc(enc) + fd, err := protoregistry.GlobalFiles.FindFileByPath(filename) if err != nil { - panic(fmt.Sprintf("failed to decode enc: %v", err)) + panic(err) } - b, err := proto.Marshal(fd) + fdProto := protodesc.ToFileDescriptorProto(fd) + b, err := proto.Marshal(fdProto) if err != nil { panic(fmt.Sprintf("failed to marshal fd: %v", err)) } - return fd, b + return fdProto, b } func loadFileDescDynamic(b []byte) (*dpb.FileDescriptorProto, protoreflect.FileDescriptor, []byte) { @@ -135,9 +135,18 @@ func (x) TestFileDescContainingExtension(t *testing.T) { {"grpc.testing.ToBeExtended", 23, fdProto2Ext2}, {"grpc.testing.ToBeExtended", 29, fdProto2Ext2}, } { - fd, err := fileDescContainingExtension(test.st, test.extNum) - if err != nil || !proto.Equal(fd, test.want) { - t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, ", test.st, fd, err, test.want) + fd, err := s.fileDescEncodingContainingExtension(test.st, test.extNum, map[string]struct{}{}) + if err != nil { + t.Errorf("fileDescContainingExtension(%q) return error: %v", test.st, err) + continue + } + var actualFd descriptorpb.FileDescriptorProto + if err := proto.Unmarshal(fd[0], &actualFd); err != nil { + t.Errorf("fileDescContainingExtension(%q) return invalid bytes: %v", test.st, err) + continue + } + if !proto.Equal(&actualFd, test.want) { + t.Errorf("fileDescContainingExtension(%q) returned %q, but wanted %q", test.st, &actualFd, test.want) } } } @@ -348,7 +357,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe {"grpc.testingv3.SearchResponseV3.Result.Value.val", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.Result.Value.str", fdTestv3Byte}, {"grpc.testingv3.SearchResponseV3.State", fdTestv3Byte}, - {"grpc.testingv3.SearchResponseV3.State.FRESH", fdTestv3Byte}, + {"grpc.testingv3.SearchResponseV3.FRESH", fdTestv3Byte}, // Test dynamic symbols {"grpc.testing.DynamicService", fdDynamicByte}, {"grpc.testing.DynamicReq", fdDynamicByte},