diff --git a/aligned.go b/aligned.go new file mode 100644 index 0000000..71bee5a --- /dev/null +++ b/aligned.go @@ -0,0 +1,25 @@ +package proto + +import "strings" + +type aligned struct { + source string + left bool +} + +var alignedEquals = leftAligned(" = ") + +func leftAligned(src string) aligned { return aligned{src, true} } +func rightAligned(src string) aligned { return aligned{src, false} } + +func (a aligned) preferredWidth() int { return len(a.source) } + +func (a aligned) formatted(width int) string { + if len(a.source) > width { + return a.source[:width] + } + if a.left { + return a.source + strings.Repeat(" ", width-len(a.source)) + } + return strings.Repeat(" ", width-len(a.source)) + a.source +} diff --git a/cmd/protofmt/main.go b/cmd/protofmt/main.go index 8cf7732..9b0852b 100644 --- a/cmd/protofmt/main.go +++ b/cmd/protofmt/main.go @@ -22,8 +22,5 @@ func main() { if err != nil { log.Fatalln("protofmt failed", err) } - f := &formatter{w: os.Stdout, indentSeparator: " "} - for _, each := range def.Elements { - each.Accept(f) - } + proto.NewFormatter(os.Stdout, " ").Format(def) } diff --git a/cmd/protofmt/unformatted.proto b/cmd/protofmt/unformatted.proto index de67d42..b310c92 100644 --- a/cmd/protofmt/unformatted.proto +++ b/cmd/protofmt/unformatted.proto @@ -39,13 +39,16 @@ message Message BILL_BAILEY = 3; } + reserved 2, 15, 9 to 11; + reserved "foo", "bar"; + map terrain = 4; enum EnumAllowingAlias { option allow_alias = true; UNKNOWN = 0; STARTED = 1; - RUNNING = 2 [(custom_option) = "hello world"]; + RUN = 2 [(custom_option) = "hello world"]; } } diff --git a/enum.go b/enum.go index 5af47ef..0f396af 100644 --- a/enum.go +++ b/enum.go @@ -1,6 +1,9 @@ package proto -import "fmt" +import ( + "fmt" + "strconv" +) // Enum definition consists of a name and an enum body. type Enum struct { @@ -26,6 +29,15 @@ func (f *EnumField) Accept(v Visitor) { v.VisitEnumField(f) } +// columns returns printable source tokens +func (f EnumField) columns() (cols []aligned) { + cols = append(cols, leftAligned(f.Name), alignedEquals, leftAligned(strconv.Itoa(f.Integer))) + if f.ValueOption != nil { + cols = append(cols, f.ValueOption.columns()...) + } + return +} + func (f *EnumField) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { diff --git a/cmd/protofmt/formatter.go b/formatter.go similarity index 62% rename from cmd/protofmt/formatter.go rename to formatter.go index 69ff386..2c6ffa8 100644 --- a/cmd/protofmt/formatter.go +++ b/formatter.go @@ -1,22 +1,34 @@ -package main +package proto import ( "fmt" "io" "strings" - - "github.com/emicklei/proto" ) -type formatter struct { +// Formatter visits a Proto and writes formatted source. +type Formatter struct { w io.Writer indentLevel int lastStmt string indentSeparator string } -func (f *formatter) VisitComment(c *proto.Comment) { +// NewFormatter returns a new Formatter. Only the indentation separator is configurable. +func NewFormatter(writer io.Writer, indentSeparator string) *Formatter { + return &Formatter{w: writer, indentSeparator: indentSeparator} +} + +// Format visits all proto elements and writes formatted source. +func (f *Formatter) Format(p *Proto) { + for _, each := range p.Elements { + each.Accept(f) + } +} + +// VisitComment formats a Comment. +func (f *Formatter) VisitComment(c *Comment) { f.begin("comment") if c.IsMultiline() { fmt.Fprintln(f.w, "/*") @@ -27,18 +39,18 @@ func (f *formatter) VisitComment(c *proto.Comment) { } } -func (f *formatter) VisitEnum(e *proto.Enum) { +// VisitEnum formats a Enum. +func (f *Formatter) VisitEnum(e *Enum) { f.begin("enum") fmt.Fprintf(f.w, "enum %s {", e.Name) f.indentLevel++ - for _, each := range e.Elements { - each.Accept(f) - } + f.printAsGroups(e.Elements) f.indent(-1) io.WriteString(f.w, "}\n") } -func (f *formatter) VisitEnumField(e *proto.EnumField) { +// VisitEnumField formats a EnumField. +func (f *Formatter) VisitEnumField(e *EnumField) { f.begin("field") io.WriteString(f.w, paddedTo(e.Name, 10)) fmt.Fprintf(f.w, " = %d", e.Integer) @@ -50,7 +62,8 @@ func (f *formatter) VisitEnumField(e *proto.EnumField) { } } -func (f *formatter) VisitImport(i *proto.Import) { +// VisitImport formats a Import. +func (f *Formatter) VisitImport(i *Import) { f.begin("import") if len(i.Kind) > 0 { fmt.Fprintf(f.w, "import %s ", i.Kind) @@ -58,7 +71,8 @@ func (f *formatter) VisitImport(i *proto.Import) { fmt.Fprintf(f.w, "import %q;\n", i.Filename) } -func (f *formatter) VisitMessage(m *proto.Message) { +// VisitMessage formats a Message. +func (f *Formatter) VisitMessage(m *Message) { f.begin("message") fmt.Fprintf(f.w, "message %s {", m.Name) f.newLineIf(len(m.Elements) > 0) @@ -70,7 +84,8 @@ func (f *formatter) VisitMessage(m *proto.Message) { io.WriteString(f.w, "}\n") } -func (f *formatter) VisitOption(o *proto.Option) { +// VisitOption formats a Option. +func (f *Formatter) VisitOption(o *Option) { if o.IsEmbedded { io.WriteString(f.w, "[(") } else { @@ -92,12 +107,14 @@ func (f *formatter) VisitOption(o *proto.Option) { } } -func (f *formatter) VisitPackage(p *proto.Package) { +// VisitPackage formats a Package. +func (f *Formatter) VisitPackage(p *Package) { f.begin("package") fmt.Fprintf(f.w, "package %s;\n", p.Name) } -func (f *formatter) VisitService(s *proto.Service) { +// VisitService formats a Service. +func (f *Formatter) VisitService(s *Service) { f.begin("service") fmt.Fprintf(f.w, "service %s {", s.Name) f.indentLevel++ @@ -108,11 +125,13 @@ func (f *formatter) VisitService(s *proto.Service) { io.WriteString(f.w, "}\n") } -func (f *formatter) VisitSyntax(s *proto.Syntax) { +// VisitSyntax formats a Syntax. +func (f *Formatter) VisitSyntax(s *Syntax) { fmt.Fprintf(f.w, "syntax = %q;\n\n", s.Value) } -func (f *formatter) VisitOneof(o *proto.Oneof) { +// VisitOneof formats a Oneof. +func (f *Formatter) VisitOneof(o *Oneof) { f.begin("oneof") fmt.Fprintf(f.w, "oneof %s {", o.Name) f.indentLevel++ @@ -123,7 +142,8 @@ func (f *formatter) VisitOneof(o *proto.Oneof) { io.WriteString(f.w, "}\n") } -func (f *formatter) VisitOneofField(o *proto.OneOfField) { +// VisitOneofField formats a OneofField. +func (f *Formatter) VisitOneofField(o *OneOfField) { f.begin("oneoffield") fmt.Fprintf(f.w, "%s %s = %d", o.Type, o.Name, o.Sequence) for _, each := range o.Options { @@ -132,9 +152,10 @@ func (f *formatter) VisitOneofField(o *proto.OneOfField) { io.WriteString(f.w, ";\n") } -func (f *formatter) VisitReserved(r *proto.Reserved) { +// VisitReserved formats a Reserved. +func (f *Formatter) VisitReserved(r *Reserved) { f.begin("reserved") - io.WriteString(f.w, "reserved") + io.WriteString(f.w, "reserved ") if len(r.Ranges) > 0 { io.WriteString(f.w, r.Ranges) } else { @@ -148,7 +169,8 @@ func (f *formatter) VisitReserved(r *proto.Reserved) { io.WriteString(f.w, ";\n") } -func (f *formatter) VisitRPC(r *proto.RPC) { +// VisitRPC formats a RPC. +func (f *Formatter) VisitRPC(r *RPC) { f.begin("rpc") fmt.Fprintf(f.w, "rpc %s (", r.Name) if r.StreamsRequest { @@ -163,12 +185,14 @@ func (f *formatter) VisitRPC(r *proto.RPC) { io.WriteString(f.w, ");\n") } -func (f *formatter) VisitMapField(m *proto.MapField) { +// VisitMapField formats a MapField. +func (f *Formatter) VisitMapField(m *MapField) { f.begin("map") fmt.Fprintf(f.w, "map<%s,%s> %s = %d;\n", m.KeyType, m.Type, m.Name, m.Sequence) } -func (f *formatter) VisitNormalField(f1 *proto.NormalField) { +// VisitNormalField formats a NormalField. +func (f *Formatter) VisitNormalField(f1 *NormalField) { f.begin("field") if f1.Repeated { io.WriteString(f.w, "repeated ") @@ -178,36 +202,3 @@ func (f *formatter) VisitNormalField(f1 *proto.NormalField) { } fmt.Fprintf(f.w, "%s %s = %d;\n", f1.Type, f1.Name, f1.Sequence) } - -// Utils - -func (f *formatter) begin(stmt string) { - if f.lastStmt != stmt && len(f.lastStmt) > 0 { // not the first line - // add separator because stmt is changed, unless it nested thingy - if !strings.Contains("comment", f.lastStmt) { - io.WriteString(f.w, "\n") - } - } - f.indent(0) - f.lastStmt = stmt -} - -func (f *formatter) indent(diff int) { - f.indentLevel += diff - for i := 0; i < f.indentLevel; i++ { - io.WriteString(f.w, f.indentSeparator) - } -} - -func paddedTo(word string, length int) string { - if len(word) >= length { - return word - } - return word + strings.Repeat(" ", length-len(word)) -} - -func (f *formatter) newLineIf(ok bool) { - if ok { - io.WriteString(f.w, "\n") - } -} diff --git a/formatter_test.go b/formatter_test.go new file mode 100644 index 0000000..49bb37c --- /dev/null +++ b/formatter_test.go @@ -0,0 +1,38 @@ +package proto + +import ( + "bytes" + "testing" +) + +func TestPrintListOfColumns(t *testing.T) { + e0 := new(EnumField) + e0.Name = "A" + e0.Integer = 1 + op0 := new(Option) + op0.IsEmbedded = true + op0.Name = "a" + op0.Constant = Literal{Source: "1234"} + e0.ValueOption = op0 + + e1 := new(EnumField) + e1.Name = "ABC" + e1.Integer = 12 + op1 := new(Option) + op1.IsEmbedded = true + op1.Name = "ab" + op1.Constant = Literal{Source: "1234"} + e1.ValueOption = op1 + + list := []columnsPrintable{e0, e1} + b := new(bytes.Buffer) + f := NewFormatter(b, " ") + f.printListOfColumns(list) + formatted := ` +A = 1 [a =1234]; +ABC = 12 [ab=1234]; +` + if got, want := b.String(), formatted; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } +} diff --git a/formatter_utils.go b/formatter_utils.go new file mode 100644 index 0000000..45884e0 --- /dev/null +++ b/formatter_utils.go @@ -0,0 +1,108 @@ +package proto + +import ( + "io" + "strings" +) + +func (f *Formatter) begin(stmt string) { + if f.lastStmt != stmt && len(f.lastStmt) > 0 { // not the first line + // add separator because stmt is changed, unless it nested thingy + if !strings.Contains("comment", f.lastStmt) { + io.WriteString(f.w, "\n") + } + } + f.indent(0) + f.lastStmt = stmt +} + +func (f *Formatter) indent(diff int) { + f.indentLevel += diff + for i := 0; i < f.indentLevel; i++ { + io.WriteString(f.w, f.indentSeparator) + } +} + +type columnsPrintable interface { + columns() (cols []aligned) +} + +func (f *Formatter) printListOfColumns(list []columnsPrintable) { + // collect all column values + values := [][]aligned{} + widths := map[int]int{} + for _, each := range list { + cols := each.columns() + values = append(values, cols) + // update max widths per column + for i, other := range cols { + pw := other.preferredWidth() + w, ok := widths[i] + if ok { + if pw > w { + widths[i] = pw + } + } else { + widths[i] = pw + } + } + } + // now print all values + for _, each := range values { + io.WriteString(f.w, "\n") + f.indent(0) + for c := 0; c < len(widths); c++ { + pw := widths[c] + // only print if there is a value + if c < len(each) { + // using space padding to match the max width + src := each[c].formatted(pw) + io.WriteString(f.w, src) + } + } + io.WriteString(f.w, ";") + } + io.WriteString(f.w, "\n") +} + +func paddedTo(word string, length int) string { + if len(word) >= length { + return word + } + return word + strings.Repeat(" ", length-len(word)) +} + +func (f *Formatter) newLineIf(ok bool) { + if ok { + io.WriteString(f.w, "\n") + } +} + +func (f *Formatter) printAsGroups(list []Visitee) { + if len(list) == 0 { + return + } + group := []columnsPrintable{} + lastGroupName := nameOfVisitee(list[0]) + for i := 1; i < len(list); i++ { + groupName := nameOfVisitee(list[i]) + printable, isColumnsPrintable := list[i].(columnsPrintable) + if isColumnsPrintable { + if lastGroupName == groupName { + // collect in group + group = append(group, printable) + } else { + // print current group + f.printListOfColumns(group) + lastGroupName = groupName + // begin new group + group = []columnsPrintable{printable} + } + } else { + // not printable in group + list[i].Accept(f) + } + } + // print last group + f.printListOfColumns(group) +} diff --git a/option.go b/option.go index 3866113..15b7f22 100644 --- a/option.go +++ b/option.go @@ -15,6 +15,20 @@ func (o *Option) Accept(v Visitor) { v.VisitOption(o) } +// columns returns printable source tokens +func (o *Option) columns() (cols []aligned) { + if !o.IsEmbedded { + cols = append(cols, leftAligned("option")) + } else { + cols = append(cols, leftAligned(" [")) + } + cols = append(cols, leftAligned(o.Name), leftAligned("="), rightAligned(o.Constant.String())) + if o.IsEmbedded { + cols = append(cols, leftAligned("]")) + } + return +} + // parse reads an Option body // ( ident | "(" fullIdent ")" ) { "." ident } "=" constant ";" func (o *Option) parse(p *Parser) error { diff --git a/visitor.go b/visitor.go index f8c3ef7..addfc11 100644 --- a/visitor.go +++ b/visitor.go @@ -23,3 +23,29 @@ type Visitor interface { type Visitee interface { Accept(v Visitor) } + +type reflector struct { + name string +} + +func (r *reflector) VisitMessage(m *Message) { r.name = "Message" } +func (r *reflector) VisitService(v *Service) { r.name = "Service" } +func (r *reflector) VisitSyntax(s *Syntax) { r.name = "Syntax" } +func (r *reflector) VisitPackage(p *Package) { r.name = "Package" } +func (r *reflector) VisitOption(o *Option) { r.name = "Option" } +func (r *reflector) VisitImport(i *Import) { r.name = "Import" } +func (r *reflector) VisitNormalField(i *NormalField) { r.name = "NormalField" } +func (r *reflector) VisitEnumField(i *EnumField) { r.name = "EnumField" } +func (r *reflector) VisitEnum(e *Enum) { r.name = "Enum" } +func (r *reflector) VisitComment(e *Comment) { r.name = "Comment" } +func (r *reflector) VisitOneof(o *Oneof) { r.name = "Oneof" } +func (r *reflector) VisitOneofField(o *OneOfField) { r.name = "OneOfField" } +func (r *reflector) VisitReserved(rs *Reserved) { r.name = "Reserved" } +func (r *reflector) VisitRPC(rpc *RPC) { r.name = "RPC" } +func (r *reflector) VisitMapField(f *MapField) { r.name = "MapField" } + +func nameOfVisitee(e Visitee) string { + r := new(reflector) + e.Accept(r) + return r.name +}