Skip to content

Commit

Permalink
pb: Preload well-known types, remove -D flag
Browse files Browse the repository at this point in the history
Pre-load all the well-known types in `google.golang.org/protobuf/types`
so they do not need to be present in a protoset. This allows us to
remove the `-D` flag, which decoded the input as a `FileDescriptorSet`.
Instead, the message type `FileDescriptorSet` must now be given
explicitly on the command line.

e.g.

    protoc -o/dev/stdout foo.proto | pb -Ip FileDescriptorSet

This makes pb a little more flexible in that any of the well-known types
can be decoded without needing them described in a protoset, rather than
just `FileDescriptorSet`.

A planned later change to allow message name aliases should make this
simpler by allowing `fds` to be used as the message type instead of the
more verbose `FileDescriptorSet`. Another planned change, which will
detect binary input data as the binary protobuf format, means the above
command could soon be expressed as:

    protoc -o/dev/stdout foo.proto | pb fds
  • Loading branch information
camh- committed Apr 23, 2022
1 parent 7029a23 commit 5a17c58
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 110 deletions.
144 changes: 52 additions & 92 deletions cmd/pb/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"errors"
"fmt"
"io"
"os"
Expand All @@ -16,8 +15,19 @@ import (
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
_ "google.golang.org/protobuf/types/known/anypb"
_ "google.golang.org/protobuf/types/known/apipb"
_ "google.golang.org/protobuf/types/known/durationpb"
_ "google.golang.org/protobuf/types/known/emptypb"
_ "google.golang.org/protobuf/types/known/fieldmaskpb"
_ "google.golang.org/protobuf/types/known/sourcecontextpb"
_ "google.golang.org/protobuf/types/known/structpb"
_ "google.golang.org/protobuf/types/known/timestamppb"
_ "google.golang.org/protobuf/types/known/typepb"
_ "google.golang.org/protobuf/types/known/wrapperspb"
_ "google.golang.org/protobuf/types/pluginpb"
)

var (
Expand All @@ -36,21 +46,23 @@ pb translates encoded Protobuf message from one format to another
)

// PBConfig holds the flags and positional arguments of the pb command, as
// declared by the struct tags on each field, plus the unexported type
// registry used while translating a message.
type PBConfig struct {
Protoset *registry.Files `short:"P" help:"Protoset of Message being translated" xor:"protoset"`
Descriptorpb bool `short:"D" help:"Use descriptorpb as protoset" xor:"protoset"`
Out string `short:"o" help:"Output file name"`
InFormat string `short:"I" help:"Input format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""`
OutFormat string `short:"O" help:"Output format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""`
Zero bool `short:"z" help:"Print zero values in JSON output"`
MessageType string `arg:"" help:"Message type to be translated" optional:""`
In string `arg:"" help:"Message value JSON encoded" optional:""`
Protoset *descriptorpb.FileDescriptorSet `short:"P" help:"Protoset containing Message to be translated"`

Out string `short:"o" help:"Output file name"`
InFormat string `short:"I" help:"Input format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""`
OutFormat string `short:"O" help:"Output format (j[son], p[b], t[xt])" enum:"json,pb,txt,j,p,t," default:""`
Zero bool `short:"z" help:"Print zero values in JSON output"`
MessageType string `arg:"" help:"Message type to be translated"`
In string `arg:"" help:"Message value JSON encoded" optional:""`

// types is assigned in Run and passed as the Resolver to the
// unmarshal/marshal options and to lookupMessage.
types *protoregistry.Types
}

// main parses the command line into cli with kong, registering a custom
// mapper for the type of the Protoset field, then runs the parsed command
// and hands any resulting error to kctx.FatalIfErrorf.
func main() {
kctx := kong.Parse(&cli,
kong.Description(description),
kong.Vars{"version": fmt.Sprintf("%s (%s on %s)", version, commit, date)},
kong.TypeMapper(reflect.TypeOf(cli.PBConfig.Protoset), kong.MapperFunc(registryMapper)),
kong.TypeMapper(reflect.TypeOf(cli.PBConfig.Protoset), kong.MapperFunc(fdsMapper)),
)
kctx.FatalIfErrorf(kctx.Run())
}
Expand All @@ -59,29 +71,14 @@ type unmarshaler func([]byte, proto.Message) error
type marshaler func(proto.Message) ([]byte, error)

func (c *PBConfig) Run() error {
if c.Descriptorpb {
c.Protoset = &registry.Files{}
err := c.Protoset.RegisterFile(descriptorpb.File_google_protobuf_descriptor_proto)
if err != nil {
c.types = registry.CloneTypes(protoregistry.GlobalTypes)
if c.Protoset != nil {
if err := registry.AddDynamicTypes(c.types, c.Protoset); err != nil {
return err
}
if c.MessageType != "" && c.In == "" {
// shuffle down the args and provide default MessageType
c.In = c.MessageType
c.MessageType = ""
}
if c.MessageType == "" {
c.MessageType = ".google.protobuf.FileDescriptorSet"
}
if c.In == "" && c.InFormat == "" {
c.InFormat = "pb"
}
}
if c.MessageType == "" {
return errors.New(`expected "<message-type>"`)
}

md, err := lookupMessage(c.Protoset, c.MessageType)
mt, err := lookupMessage(c.types, c.MessageType)
if err != nil {
return err
}
Expand All @@ -93,7 +90,7 @@ func (c *PBConfig) Run() error {
if err != nil {
return fmt.Errorf("cannot decode %q input: %w", c.inFormat(), err)
}
message := dynamicpb.NewMessage(md)
message := mt.New().Interface()
if err := unmarshal(in, message); err != nil {
return err
}
Expand All @@ -112,9 +109,6 @@ func (c *PBConfig) AfterApply() error {
if c.Zero && c.outFormat() != "json" {
return fmt.Errorf(`cannot print zero values with %q, only "json"`, c.outFormat())
}
if !c.Descriptorpb && c.Protoset == nil {
return errors.New(`either "-p/--protoset=PROTOSET" or -D/--descriptorpb required`)
}
return nil
}

Expand Down Expand Up @@ -142,13 +136,13 @@ func (c *PBConfig) writeOutput(b []byte) error {
func (c *PBConfig) unmarshaler() (unmarshaler, error) {
switch c.inFormat() {
case "json":
o := protojson.UnmarshalOptions{Resolver: c.Protoset}
o := protojson.UnmarshalOptions{Resolver: c.types}
return o.Unmarshal, nil
case "pb":
o := proto.UnmarshalOptions{Resolver: c.Protoset}
o := proto.UnmarshalOptions{Resolver: c.types}
return o.Unmarshal, nil
case "txt":
o := prototext.UnmarshalOptions{Resolver: c.Protoset}
o := prototext.UnmarshalOptions{Resolver: c.types}
return o.Unmarshal, nil
}
return nil, fmt.Errorf("unknown input format %q", c.inFormat())
Expand All @@ -166,7 +160,7 @@ func (c *PBConfig) marshaler() (marshaler, error) {
switch c.outFormat() {
case "json":
o := protojson.MarshalOptions{
Resolver: c.Protoset,
Resolver: c.types,
Multiline: true,
EmitUnpopulated: c.Zero,
}
Expand All @@ -181,7 +175,7 @@ func (c *PBConfig) marshaler() (marshaler, error) {
o := proto.MarshalOptions{}
return o.Marshal, nil
case "txt":
o := prototext.MarshalOptions{Resolver: c.Protoset, Multiline: true}
o := prototext.MarshalOptions{Resolver: c.types, Multiline: true}
return o.Marshal, nil
}
return nil, fmt.Errorf("unknown output format %s", c.outFormat())
Expand Down Expand Up @@ -211,17 +205,20 @@ func canonicalFormat(format string) string {
return format
}

func lookupMessage(reg *registry.Files, name string) (protoreflect.MessageDescriptor, error) {
var result []protoreflect.MessageDescriptor
reg.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
for i := 0; i < fd.Messages().Len(); i++ {
md := fd.Messages().Get(i)
mds, exactMatch := lookupMessageInMD(md, name)
if exactMatch {
result = mds
return false
}
result = append(result, mds...)
func lookupMessage(types *protoregistry.Types, name string) (protoreflect.MessageType, error) {
var result []protoreflect.MessageType
types.RangeMessages(func(mt protoreflect.MessageType) bool {
mdName := string(mt.Descriptor().FullName())
if name == mdName || name == "."+mdName {
// If we have a full name match, we're done and will also
// ignore any other partial name matches.
result = []protoreflect.MessageType{mt}
return false
}
mdLowerName := "." + strings.ToLower(mdName)
lowerName := strings.ToLower(name)
if lowerName == mdLowerName || strings.HasSuffix(mdLowerName, "."+lowerName) {
result = append(result, mt)
}
return true
})
Expand All @@ -235,57 +232,20 @@ func lookupMessage(reg *registry.Files, name string) (protoreflect.MessageDescri
return result[0], nil
}

// lookupMessageInMD searches md and, recursively, its nested messages for
// message descriptors matching name.
//
// An exact match is a descriptor whose full name equals name, with or
// without a leading dot. It is returned on its own with exactMatch set to
// true and stops the search, discarding any partial matches found so far.
//
// Otherwise each descriptor whose lower-cased, dot-prefixed full name equals
// the lower-cased name, or ends with "." followed by the lower-cased name, is
// collected as a partial match and the result is returned with exactMatch
// set to false.
func lookupMessageInMD(md protoreflect.MessageDescriptor, name string) (mds []protoreflect.MessageDescriptor, exactMatch bool) {
	fullName := string(md.FullName())
	switch name {
	case fullName, "." + fullName:
		// Full-name hit: report only this descriptor.
		return []protoreflect.MessageDescriptor{md}, true
	}

	want := strings.ToLower(name)
	have := "." + strings.ToLower(fullName)
	if want == have || strings.HasSuffix(have, "."+want) {
		mds = append(mds, md)
	}

	nested := md.Messages()
	for i := 0; i < nested.Len(); i++ {
		found, exact := lookupMessageInMD(nested.Get(i), name)
		if exact {
			return found, true
		}
		mds = append(mds, found...)
	}
	return mds, false
}

func registryMapper(kctx *kong.DecodeContext, target reflect.Value) error {
reg, ok := target.Interface().(*registry.Files)
// fdsMapper is the kong mapper for the Protoset flag. It pops a file name
// from the scanner, reads that file and unmarshals its contents into the
// *descriptorpb.FileDescriptorSet held by target. It panics if target does
// not hold that type.
func fdsMapper(kctx *kong.DecodeContext, target reflect.Value) error {
fds, ok := target.Interface().(*descriptorpb.FileDescriptorSet)
if !ok {
panic("target is not a *registry.Files")
panic("target is not a *descriptorpb.FileDescriptorSet")
}
var filename string
if err := kctx.Scan.PopValueInto("file", &filename); err != nil {
return err
}
files, err := registryFiles(filename)
if err != nil {
return err
}
*reg = *files
return nil
}

// registryFiles reads filename, unmarshals it as a FileDescriptorSet and
// builds a *registry.Files from it via registry.NewFiles.
func registryFiles(filename string) (*registry.Files, error) {
b, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
fds := descriptorpb.FileDescriptorSet{}
if err := proto.Unmarshal(b, &fds); err != nil {
return nil, err
return err
}
return registry.NewFiles(&fds)
return proto.Unmarshal(b, fds)
}

func isTTY() bool {
Expand Down
53 changes: 35 additions & 18 deletions cmd/pb/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,26 @@ import (
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
)

// newFDS is a test helper that reads the file at filename and unmarshals its
// contents into a FileDescriptorSet. Any read or unmarshal error fails the
// calling test immediately.
func newFDS(t *testing.T, filename string) *descriptorpb.FileDescriptorSet {
	t.Helper()

	raw, readErr := os.ReadFile(filename)
	require.NoError(t, readErr)

	set := new(descriptorpb.FileDescriptorSet)
	require.NoError(t, proto.Unmarshal(raw, set))
	return set
}

func TestRunJSON(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.json"),
MessageType: "BaseMessage",
In: `{"f": "F" }`,
Expand All @@ -34,11 +45,10 @@ func TestRunJSON(t *testing.T) {

func TestRunJSONZero(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.json"),
MessageType: "BaseMessage",
In: `{"f": "" }`,
Expand All @@ -57,11 +67,10 @@ func TestRunJSONZero(t *testing.T) {

func TestRunPrototext(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.txt"),
MessageType: "BaseMessage",
In: `{"f": "F" }`,
Expand All @@ -84,11 +93,10 @@ func TestRunPrototext(t *testing.T) {

func TestRunMessages(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.json"),
In: `{"f": "F" }`,
}
Expand All @@ -105,11 +113,10 @@ func TestRunMessages(t *testing.T) {

func TestRunMessageErr(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.json"),
In: `{"f": "F" }`,
}
Expand All @@ -124,18 +131,28 @@ func TestRunMessageErr(t *testing.T) {

func TestRunInErr(t *testing.T) {
tmpDir := t.TempDir()
files, err := registryFiles("testdata/pbtest.pb")
require.NoError(t, err)
fds := newFDS(t, "testdata/pbtest.pb")

cli := PBConfig{
Protoset: files,
Protoset: fds,
Out: filepath.Join(tmpDir, "out.json"),
MessageType: "BaseMessage",
In: `{"MISSING": "F" }`,
}
require.Error(t, cli.Run())
}

// TestWellKnown runs a translation of the message type "Duration" with no
// Protoset configured and checks that the JSON input `"10s"` is written
// back out unchanged.
func TestWellKnown(t *testing.T) {
	outFile := filepath.Join(t.TempDir(), "out.json")
	cfg := PBConfig{
		Out:         outFile,
		MessageType: "Duration",
		In:          `"10s"`,
	}

	require.NoError(t, cfg.Run())
	requireJSONFileContent(t, `"10s"`, cfg.Out)
}

func requireJSONFileContent(t *testing.T, wantStr string, gotFile string) {
t.Helper()
b, err := os.ReadFile(gotFile)
Expand Down

0 comments on commit 5a17c58

Please sign in to comment.