From 5a17c583c3255d666cc9aaea4b27e9da503427f0 Mon Sep 17 00:00:00 2001 From: Cam Hutchison Date: Sat, 23 Apr 2022 12:25:35 +1000 Subject: [PATCH] pb: Preload well-known types, remove -D flag Pre-load all the well-known types in `google.golang.org/protobuf/types` so they do not need to be present in a protoset. This allows us to remove the `-D` flag that decodes the input as a `FileDescriptorSet` replacing it with requiring that the message type `FileDescriptorSet` be present on the command line. e.g. protoc -o/dev/stdout foo.proto | pb -Ip FileDescriptorSet This makes pb a little more flexible in that any of the well-known types can be decoded without needing them described in a protoset, rather than just `FileDescriptorSet`. A planned later change to allow message name aliases should make this simpler by allowing `fds` to be used as the message type instead of the more verbose `FileDescriptorSet`. Another planned change to detect binary input data as the binary protobuf format means soon the above command could be expressed as: protoc -o/dev/stdout foo.proto | pb fds --- cmd/pb/main.go | 144 ++++++++++++++++---------------------------- cmd/pb/main_test.go | 53 ++++++++++------ 2 files changed, 87 insertions(+), 110 deletions(-) diff --git a/cmd/pb/main.go b/cmd/pb/main.go index ad0e13e..9837585 100644 --- a/cmd/pb/main.go +++ b/cmd/pb/main.go @@ -1,7 +1,6 @@ package main import ( - "errors" "fmt" "io" "os" @@ -16,8 +15,19 @@ import ( "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/dynamicpb" + _ "google.golang.org/protobuf/types/known/anypb" + _ "google.golang.org/protobuf/types/known/apipb" + _ "google.golang.org/protobuf/types/known/durationpb" + _ "google.golang.org/protobuf/types/known/emptypb" + _ "google.golang.org/protobuf/types/known/fieldmaskpb" + _ "google.golang.org/protobuf/types/known/sourcecontextpb" + _ "google.golang.org/protobuf/types/known/structpb" + _ "google.golang.org/protobuf/types/known/timestamppb" + _ "google.golang.org/protobuf/types/known/typepb" + _ "google.golang.org/protobuf/types/known/wrapperspb" + _ "google.golang.org/protobuf/types/pluginpb" ) var ( @@ -36,21 +46,23 @@ pb translates encoded Protobuf message from one format to another ) type PBConfig struct { - Protoset *registry.Files `short:"P" help:"Protoset of Message being translated" xor:"protoset"` - Descriptorpb bool `short:"D" help:"Use descriptorpb as protoset" xor:"protoset"` - Out string `short:"o" help:"Output file name"` - InFormat string `short:"I" help:"Input format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""` - OutFormat string `short:"O" help:"Output format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""` - Zero bool `short:"z" help:"Print zero values in JSON output"` - MessageType string `arg:"" help:"Message type to be translated" optional:""` - In string `arg:"" help:"Message value JSON encoded" optional:""` + Protoset *descriptorpb.FileDescriptorSet `short:"P" help:"Protoset containing Message to be translated"` + + Out string `short:"o" help:"Output file name"` + InFormat string `short:"I" help:"Input format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""` + OutFormat string `short:"O" help:"Output format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""` + Zero bool `short:"z" help:"Print zero values in JSON output"` + MessageType string `arg:"" help:"Message type to be translated"` + In string `arg:"" help:"Message value JSON encoded" optional:""` + + types *protoregistry.Types } func main() { kctx := kong.Parse(&cli, kong.Description(description), kong.Vars{"version": fmt.Sprintf("%s (%s on %s)", version, commit, date)}, - kong.TypeMapper(reflect.TypeOf(cli.PBConfig.Protoset), kong.MapperFunc(registryMapper)), + kong.TypeMapper(reflect.TypeOf(cli.PBConfig.Protoset), kong.MapperFunc(fdsMapper)), ) kctx.FatalIfErrorf(kctx.Run()) } @@ -59,29 +71,14 @@ type unmarshaler func([]byte, proto.Message) error type marshaler func(proto.Message) ([]byte, error) func (c *PBConfig) Run() error { - if c.Descriptorpb { - c.Protoset = ®istry.Files{} - err := c.Protoset.RegisterFile(descriptorpb.File_google_protobuf_descriptor_proto) - if err != nil { + c.types = registry.CloneTypes(protoregistry.GlobalTypes) + if c.Protoset != nil { + if err := registry.AddDynamicTypes(c.types, c.Protoset); err != nil { return err } - if c.MessageType != "" && c.In == "" { - // shuffle down the args and provide default MessageType - c.In = c.MessageType - c.MessageType = "" - } - if c.MessageType == "" { - c.MessageType = ".google.protobuf.FileDescriptorSet" - } - if c.In == "" && c.InFormat == "" { - c.InFormat = "pb" - } - } - if c.MessageType == "" { - return errors.New(`expected ""`) } - md, err := lookupMessage(c.Protoset, c.MessageType) + mt, err := lookupMessage(c.types, c.MessageType) if err != nil { return err } @@ -93,7 +90,7 @@ func (c *PBConfig) Run() error { if err != nil { return fmt.Errorf("cannot decode %q input: %w", c.inFormat(), err) } - message := dynamicpb.NewMessage(md) + message := mt.New().Interface() if err := unmarshal(in, message); err != nil { return err } @@ -112,9 +109,6 @@ func (c *PBConfig) AfterApply() error { if c.Zero && c.outFormat() != "json" { return fmt.Errorf(`cannot print zero values with %q, only "json"`, c.outFormat()) } - if !c.Descriptorpb && c.Protoset == nil { - return errors.New(`either "-p/--protoset=PROTOSET" or -D/--descriptorpb required`) - } return nil } @@ -142,13 +136,13 @@ func (c *PBConfig) writeOutput(b []byte) error { func (c *PBConfig) unmarshaler() (unmarshaler, error) { switch c.inFormat() { case "json": - o := protojson.UnmarshalOptions{Resolver: c.Protoset} + o := protojson.UnmarshalOptions{Resolver: c.types} return o.Unmarshal, nil case "pb": - o := proto.UnmarshalOptions{Resolver: c.Protoset} + o := proto.UnmarshalOptions{Resolver: c.types} return o.Unmarshal, nil case "txt": - o := prototext.UnmarshalOptions{Resolver: c.Protoset} + o := prototext.UnmarshalOptions{Resolver: c.types} return o.Unmarshal, nil } return nil, fmt.Errorf("unknown input format %q", c.inFormat()) @@ -166,7 +160,7 @@ func (c *PBConfig) marshaler() (marshaler, error) { switch c.outFormat() { case "json": o := protojson.MarshalOptions{ - Resolver: c.Protoset, + Resolver: c.types, Multiline: true, EmitUnpopulated: c.Zero, } @@ -181,7 +175,7 @@ func (c *PBConfig) marshaler() (marshaler, error) { o := proto.MarshalOptions{} return o.Marshal, nil case "txt": - o := prototext.MarshalOptions{Resolver: c.Protoset, Multiline: true} + o := prototext.MarshalOptions{Resolver: c.types, Multiline: true} return o.Marshal, nil } return nil, fmt.Errorf("unknown output format %s", c.outFormat()) @@ -211,17 +205,20 @@ func canonicalFormat(format string) string { return format } -func lookupMessage(reg *registry.Files, name string) (protoreflect.MessageDescriptor, error) { - var result []protoreflect.MessageDescriptor - reg.RangeFiles(func(fd protoreflect.FileDescriptor) bool { - for i := 0; i < fd.Messages().Len(); i++ { - md := fd.Messages().Get(i) - mds, exactMatch := lookupMessageInMD(md, name) - if exactMatch { - result = mds - return false - } - result = append(result, mds...) +func lookupMessage(types *protoregistry.Types, name string) (protoreflect.MessageType, error) { + var result []protoreflect.MessageType + types.RangeMessages(func(mt protoreflect.MessageType) bool { + mdName := string(mt.Descriptor().FullName()) + if name == mdName || name == "."+mdName { + // If we have a full name match, we're done and will also + // ignore any other partial name matches. + result = []protoreflect.MessageType{mt} + return false + } + mdLowerName := "." + strings.ToLower(mdName) + lowerName := strings.ToLower(name) + if lowerName == mdLowerName || strings.HasSuffix(mdLowerName, "."+lowerName) { + result = append(result, mt) } return true }) @@ -235,57 +232,20 @@ func lookupMessage(reg *registry.Files, name string) (protoreflect.MessageDescri return result[0], nil } -func lookupMessageInMD(md protoreflect.MessageDescriptor, name string) (mds []protoreflect.MessageDescriptor, exactMatch bool) { - mdName := string(md.FullName()) - if name == mdName || name == "."+mdName { - // If we have a full name match, we're done and will also - // ignore any other partial name matches. - return []protoreflect.MessageDescriptor{md}, true - } - mdLowerName := "." + strings.ToLower(mdName) - lowerName := strings.ToLower(name) - if lowerName == mdLowerName || strings.HasSuffix(mdLowerName, "."+lowerName) { - mds = append(mds, md) - } - subMessages := md.Messages() - for i := 0; i < subMessages.Len(); i++ { - md = subMessages.Get(i) - subMDs, exactMatch := lookupMessageInMD(md, name) - if exactMatch { - return subMDs, true - } - mds = append(mds, subMDs...) - } - return mds, false -} - -func registryMapper(kctx *kong.DecodeContext, target reflect.Value) error { - reg, ok := target.Interface().(*registry.Files) +func fdsMapper(kctx *kong.DecodeContext, target reflect.Value) error { + fds, ok := target.Interface().(*descriptorpb.FileDescriptorSet) if !ok { - panic("target is not a *registry.Files") + panic("target is not a *descriptorpb.FileDescriptorSet") } var filename string if err := kctx.Scan.PopValueInto("file", &filename); err != nil { return err } - files, err := registryFiles(filename) - if err != nil { - return err - } - *reg = *files - return nil -} - -func registryFiles(filename string) (*registry.Files, error) { b, err := os.ReadFile(filename) if err != nil { - return nil, err - } - fds := descriptorpb.FileDescriptorSet{} - if err := proto.Unmarshal(b, &fds); err != nil { - return nil, err + return err } - return registry.NewFiles(&fds) + return proto.Unmarshal(b, fds) } func isTTY() bool { diff --git a/cmd/pb/main_test.go b/cmd/pb/main_test.go index 2215ed9..6cbc602 100644 --- a/cmd/pb/main_test.go +++ b/cmd/pb/main_test.go @@ -7,15 +7,26 @@ import ( "testing" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" ) +func newFDS(t *testing.T, filename string) *descriptorpb.FileDescriptorSet { + t.Helper() + b, err := os.ReadFile(filename) + require.NoError(t, err) + fds := descriptorpb.FileDescriptorSet{} + err = proto.Unmarshal(b, &fds) + require.NoError(t, err) + return &fds +} + func TestRunJSON(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.json"), MessageType: "BaseMessage", In: `{"f": "F" }`, @@ -34,11 +45,10 @@ func TestRunJSON(t *testing.T) { func TestRunJSONZero(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.json"), MessageType: "BaseMessage", In: `{"f": "" }`, @@ -57,11 +67,10 @@ func TestRunJSONZero(t *testing.T) { func TestRunPrototext(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.txt"), MessageType: "BaseMessage", In: `{"f": "F" }`, @@ -84,11 +93,10 @@ func TestRunPrototext(t *testing.T) { func TestRunMessages(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.json"), In: `{"f": "F" }`, } @@ -105,11 +113,10 @@ func TestRunMessages(t *testing.T) { func TestRunMessageErr(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.json"), In: `{"f": "F" }`, } @@ -124,11 +131,10 @@ func TestRunMessageErr(t *testing.T) { func TestRunInErr(t *testing.T) { tmpDir := t.TempDir() - files, err := registryFiles("testdata/pbtest.pb") - require.NoError(t, err) + fds := newFDS(t, "testdata/pbtest.pb") cli := PBConfig{ - Protoset: files, + Protoset: fds, Out: filepath.Join(tmpDir, "out.json"), MessageType: "BaseMessage", In: `{"MISSING": "F" }`, @@ -136,6 +142,17 @@ func TestRunInErr(t *testing.T) { require.Error(t, cli.Run()) } +func TestWellKnown(t *testing.T) { + tmpDir := t.TempDir() + cli := PBConfig{ + Out: filepath.Join(tmpDir, "out.json"), + MessageType: "Duration", + In: `"10s"`, + } + require.NoError(t, cli.Run()) + requireJSONFileContent(t, `"10s"`, cli.Out) +} + func requireJSONFileContent(t *testing.T, wantStr string, gotFile string) { t.Helper() b, err := os.ReadFile(gotFile)