From 394737f2629f9019623cf5e945c13b3b93ac34b3 Mon Sep 17 00:00:00 2001 From: Ernest Micklei Date: Wed, 1 Feb 2017 23:43:21 +0100 Subject: [PATCH] work on aligned fields in formatter --- cmd/protofmt/unformatted.proto | 5 ++ enum.go | 2 +- formatter.go | 86 ++++++++++------------------------ formatter_test.go | 5 +- formatter_utils.go | 25 ++++++---- option.go | 4 +- service.go | 25 ++++++++++ service_test.go | 7 ++- 8 files changed, 79 insertions(+), 80 deletions(-) diff --git a/cmd/protofmt/unformatted.proto b/cmd/protofmt/unformatted.proto index b310c92..4cb4d54 100644 --- a/cmd/protofmt/unformatted.proto +++ b/cmd/protofmt/unformatted.proto @@ -52,6 +52,11 @@ message Message } } + +service SearchService { // comment + rpc Search ( SearchRequest ) returns ( SearchResponse ); + rpc Find ( Finder ) returns ( stream Result );} + // emptiness enum Enum {} service Service {} diff --git a/enum.go b/enum.go index 0f396af..0e46571 100644 --- a/enum.go +++ b/enum.go @@ -31,7 +31,7 @@ func (f *EnumField) Accept(v Visitor) { // columns returns printable source tokens func (f EnumField) columns() (cols []aligned) { - cols = append(cols, leftAligned(f.Name), alignedEquals, leftAligned(strconv.Itoa(f.Integer))) + cols = append(cols, leftAligned(f.Name), alignedEquals, rightAligned(strconv.Itoa(f.Integer))) if f.ValueOption != nil { cols = append(cols, f.ValueOption.columns()...) } diff --git a/formatter.go b/formatter.go index 2c6ffa8..20ea4db 100644 --- a/formatter.go +++ b/formatter.go @@ -43,24 +43,17 @@ func (f *Formatter) VisitComment(c *Comment) { func (f *Formatter) VisitEnum(e *Enum) { f.begin("enum") fmt.Fprintf(f.w, "enum %s {", e.Name) - f.indentLevel++ - f.printAsGroups(e.Elements) - f.indent(-1) + if len(e.Elements) > 0 { + f.nl() + f.indentLevel++ + f.printAsGroups(e.Elements) + f.indent(-1) + } io.WriteString(f.w, "}\n") } // VisitEnumField formats a EnumField. -func (f *Formatter) VisitEnumField(e *EnumField) { - f.begin("field") - io.WriteString(f.w, paddedTo(e.Name, 10)) - fmt.Fprintf(f.w, " = %d", e.Integer) - if e.ValueOption != nil { - io.WriteString(f.w, " ") - e.ValueOption.Accept(f) - } else { - io.WriteString(f.w, ";\n") - } -} +func (f *Formatter) VisitEnumField(e *EnumField) {} // VisitImport formats a Import. func (f *Formatter) VisitImport(i *Import) { @@ -75,37 +68,17 @@ func (f *Formatter) VisitImport(i *Import) { func (f *Formatter) VisitMessage(m *Message) { f.begin("message") fmt.Fprintf(f.w, "message %s {", m.Name) - f.newLineIf(len(m.Elements) > 0) - f.indentLevel++ - for _, each := range m.Elements { - each.Accept(f) + if len(m.Elements) > 0 { + f.nl() + f.indentLevel++ + f.printAsGroups(m.Elements) + f.indent(-1) } - f.indent(-1) io.WriteString(f.w, "}\n") } // VisitOption formats a Option. -func (f *Formatter) VisitOption(o *Option) { - if o.IsEmbedded { - io.WriteString(f.w, "[(") - } else { - f.begin("option") - io.WriteString(f.w, "option ") - } - if len(o.Name) > 0 { - io.WriteString(f.w, o.Name) - } - if o.IsEmbedded { - io.WriteString(f.w, ")") - } - io.WriteString(f.w, " = ") - io.WriteString(f.w, o.Constant.String()) - if o.IsEmbedded { - io.WriteString(f.w, "];\n") - } else { - io.WriteString(f.w, ";\n") - } -} +func (f *Formatter) VisitOption(o *Option) {} // VisitPackage formats a Package. func (f *Formatter) VisitPackage(p *Package) { @@ -117,11 +90,12 @@ func (f *Formatter) VisitPackage(p *Package) { func (f *Formatter) VisitService(s *Service) { f.begin("service") fmt.Fprintf(f.w, "service %s {", s.Name) - f.indentLevel++ - for _, each := range s.Elements { - each.Accept(f) + if len(s.Elements) > 0 { + f.nl() + f.indentLevel++ + f.printAsGroups(s.Elements) + f.indent(-1) } - f.indent(-1) io.WriteString(f.w, "}\n") } @@ -134,11 +108,12 @@ func (f *Formatter) VisitSyntax(s *Syntax) { func (f *Formatter) VisitOneof(o *Oneof) { f.begin("oneof") fmt.Fprintf(f.w, "oneof %s {", o.Name) - f.indentLevel++ - for _, each := range o.Elements { - each.Accept(f) + if len(o.Elements) > 0 { + f.nl() + f.indentLevel++ + f.printAsGroups(o.Elements) + f.indent(-1) } - f.indent(-1) io.WriteString(f.w, "}\n") } @@ -170,20 +145,7 @@ func (f *Formatter) VisitReserved(r *Reserved) { } // VisitRPC formats a RPC. -func (f *Formatter) VisitRPC(r *RPC) { - f.begin("rpc") - fmt.Fprintf(f.w, "rpc %s (", r.Name) - if r.StreamsRequest { - io.WriteString(f.w, "stream ") - } - io.WriteString(f.w, r.RequestType) - io.WriteString(f.w, ") returns (") - if r.StreamsReturns { - io.WriteString(f.w, "stream ") - } - io.WriteString(f.w, r.ReturnsType) - io.WriteString(f.w, ");\n") -} +func (f *Formatter) VisitRPC(r *RPC) {} // VisitMapField formats a MapField. func (f *Formatter) VisitMapField(m *MapField) { diff --git a/formatter_test.go b/formatter_test.go index 49bb37c..91adbf5 100644 --- a/formatter_test.go +++ b/formatter_test.go @@ -28,9 +28,8 @@ func TestPrintListOfColumns(t *testing.T) { b := new(bytes.Buffer) f := NewFormatter(b, " ") f.printListOfColumns(list) - formatted := ` -A = 1 [a =1234]; -ABC = 12 [ab=1234]; + formatted := `A = 1 [a = 1234]; +ABC = 12 [ab = 1234]; ` if got, want := b.String(), formatted; got != want { t.Errorf("got [%v] want [%v]", got, want) diff --git a/formatter_utils.go b/formatter_utils.go index 45884e0..5fe1ffc 100644 --- a/formatter_utils.go +++ b/formatter_utils.go @@ -8,7 +8,7 @@ import ( func (f *Formatter) begin(stmt string) { if f.lastStmt != stmt && len(f.lastStmt) > 0 { // not the first line // add separator because stmt is changed, unless it nested thingy - if !strings.Contains("comment", f.lastStmt) { + if !strings.Contains("comment enum service", f.lastStmt) { io.WriteString(f.w, "\n") } } @@ -28,6 +28,9 @@ type columnsPrintable interface { } func (f *Formatter) printListOfColumns(list []columnsPrintable) { + if len(list) == 0 { + return + } // collect all column values values := [][]aligned{} widths := map[int]int{} @@ -48,8 +51,10 @@ func (f *Formatter) printListOfColumns(list []columnsPrintable) { } } // now print all values - for _, each := range values { - io.WriteString(f.w, "\n") + for i, each := range values { + if i > 0 { + f.nl() + } f.indent(0) for c := 0; c < len(widths); c++ { pw := widths[c] @@ -62,9 +67,10 @@ func (f *Formatter) printListOfColumns(list []columnsPrintable) { } io.WriteString(f.w, ";") } - io.WriteString(f.w, "\n") + f.nl() } +// paddedTo return the word padding with spaces to match the length. func paddedTo(word string, length int) string { if len(word) >= length { return word @@ -72,10 +78,9 @@ func paddedTo(word string, length int) string { return word + strings.Repeat(" ", length-len(word)) } -func (f *Formatter) newLineIf(ok bool) { - if ok { - io.WriteString(f.w, "\n") - } +func (f *Formatter) nl() *Formatter { + io.WriteString(f.w, "\n") + return f } func (f *Formatter) printAsGroups(list []Visitee) { @@ -83,8 +88,8 @@ func (f *Formatter) printAsGroups(list []Visitee) { return } group := []columnsPrintable{} - lastGroupName := nameOfVisitee(list[0]) - for i := 1; i < len(list); i++ { + lastGroupName := "" + for i := 0; i < len(list); i++ { groupName := nameOfVisitee(list[i]) printable, isColumnsPrintable := list[i].(columnsPrintable) if isColumnsPrintable { diff --git a/option.go b/option.go index 15b7f22..0c8c10c 100644 --- a/option.go +++ b/option.go @@ -18,11 +18,11 @@ func (o *Option) Accept(v Visitor) { // columns returns printable source tokens func (o *Option) columns() (cols []aligned) { if !o.IsEmbedded { - cols = append(cols, leftAligned("option")) + cols = append(cols, leftAligned("option ")) } else { cols = append(cols, leftAligned(" [")) } - cols = append(cols, leftAligned(o.Name), leftAligned("="), rightAligned(o.Constant.String())) + cols = append(cols, leftAligned(o.Name), alignedEquals, rightAligned(o.Constant.String())) if o.IsEmbedded { cols = append(cols, leftAligned("]")) } diff --git a/service.go b/service.go index bc34493..c55c285 100644 --- a/service.go +++ b/service.go @@ -59,6 +59,31 @@ func (r *RPC) Accept(v Visitor) { v.VisitRPC(r) } +// columns returns printable source tokens +func (r *RPC) columns() (cols []aligned) { + cols = append(cols, + leftAligned("rpc "), + leftAligned(r.Name), + leftAligned(" (")) + stream := "" + if r.StreamsRequest { + stream = "stream " + } + cols = append(cols, + leftAligned(stream+r.RequestType), + leftAligned(") "), + leftAligned("returns"), + leftAligned(" (")) + stream = "" + if r.StreamsReturns { + stream = "stream " + } + cols = append(cols, + leftAligned(stream+r.ReturnsType), + leftAligned(")")) + return cols +} + // parse continues after reading "rpc" func (r *RPC) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() diff --git a/service_test.go b/service_test.go index 5ba9ed5..0e179b4 100644 --- a/service_test.go +++ b/service_test.go @@ -6,7 +6,6 @@ func TestService(t *testing.T) { proto := `service AccountService { // comment rpc CreateAccount (CreateAccount) returns (ServiceFault); - // comment rpc GetAccounts (stream Int64) returns (Account); }` pr, err := newParserOn(proto).Parse() @@ -14,11 +13,15 @@ func TestService(t *testing.T) { t.Fatal(err) } srv := collect(pr).Services()[0] - if got, want := len(srv.Elements), 4; got != want { + if got, want := len(srv.Elements), 3; got != want { t.Errorf("got [%v] want [%v]", got, want) } rpc1 := srv.Elements[1].(*RPC) if got, want := rpc1.Name, "CreateAccount"; got != want { t.Errorf("got [%v] want [%v]", got, want) } + rpc2 := srv.Elements[2].(*RPC) + if got, want := rpc2.Name, "GetAccounts"; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } }