diff --git a/cmd/proto3fmt/example0.proto b/cmd/proto3fmt/example0.proto new file mode 100644 index 0000000..92ada2d --- /dev/null +++ b/cmd/proto3fmt/example0.proto @@ -0,0 +1,43 @@ +// example0 +syntax = "proto3"; + + + + + // using Any +import "google/protobuf/any.proto"; + +import public "testdata/test.proto"; + + + +/* This pkg + */ +package here.proto3_proto ; + + +// from a bottle +message Message +{ + + string name =1; +// this is thing + google.protobuf.Any anything = 2 [packed=true, danger=false]; + + repeated + Message + children + = 3; + + + enum Humour { + // something we dont know + UNKNOWN = 0; + PUNS = 1; + SLAPSTICK = 2; + /* who is this? */ + BILL_BAILEY = 3; + } + + map terrain = 4; +} \ No newline at end of file diff --git a/cmd/proto3fmt/formatter.go b/cmd/proto3fmt/formatter.go new file mode 100644 index 0000000..5aac0bf --- /dev/null +++ b/cmd/proto3fmt/formatter.go @@ -0,0 +1,128 @@ +package main + +import ( + "errors" + "fmt" + "io" + + "strings" + + "github.com/emicklei/proto3" +) + +type formatter struct { + w io.Writer + indentLevel int + lastStmt string + indentSeparator string +} + +func (f *formatter) VisitComment(c *proto3.Comment) { + f.begin("comment") + if c.IsMultiline() { + fmt.Fprintln(f.w, "/*") + fmt.Fprint(f.w, strings.TrimSpace(c.Message)) + fmt.Fprintf(f.w, "\n*/\n") + } else { + fmt.Fprintf(f.w, "//%s\n", c.Message) + } +} + +func (f *formatter) VisitEnum(e *proto3.Enum) { + f.begin("enum") + fmt.Fprintf(f.w, "enum %s {\n", e.Name) + f.indentLevel++ + for _, each := range e.Elements { + each.Accept(f) + } + f.indent(-1) + io.WriteString(f.w, "}\n") +} + +func (f *formatter) VisitEnumField(e *proto3.EnumField) { + f.begin("field") + io.WriteString(f.w, paddedTo(e.Name, 10)) + if e.ValueOption != nil { + e.ValueOption.Accept(f) + } + fmt.Fprintf(f.w, " = %d;\n", e.Integer) +} + +func (f *formatter) VisitField(f1 *proto3.Field) { + f.begin("field") + if f1.Repeated { + io.WriteString(f.w, "repeated ") + } + fmt.Fprintf(f.w, "%s %s = %d;\n", f1.Type, f1.Name, f1.Sequence) +} + +func (f *formatter) VisitImport(i *proto3.Import) { + f.begin("import") + if len(i.Kind) > 0 { + fmt.Fprintf(f.w, "%s ", i.Kind) + } + fmt.Fprintf(f.w, "%q;\n", i.Filename) +} + +func (f *formatter) VisitMessage(m *proto3.Message) { + f.begin("message") + fmt.Fprintf(f.w, "message %s {\n", m.Name) + f.indentLevel++ + for _, each := range m.Elements { + each.Accept(f) + } + f.indentLevel++ + io.WriteString(f.w, "}\n") +} + +func (f *formatter) VisitOption(o *proto3.Option) { + panic(errors.New("Not implemented")) +} + +func (f *formatter) VisitPackage(p *proto3.Package) { + f.begin("package") + fmt.Fprintf(f.w, "package %s;\n", p.Name) +} + +func (f *formatter) VisitService(s *proto3.Service) { + panic(errors.New("Not implemented")) +} + +func (f *formatter) VisitSyntax(s *proto3.Syntax) { + fmt.Fprintf(f.w, "syntax = %q;\n\n", s.Value) +} + +func (f *formatter) VisitOneof(o *proto3.Oneof) { + panic(errors.New("Not implemented")) +} + +func (f *formatter) VisitOneofField(o *proto3.OneOfField) { + panic(errors.New("Not implemented")) +} + +// Utils + +func (f *formatter) begin(stmt string) { + if f.lastStmt != stmt && len(f.lastStmt) > 0 { // not the first line + // add separator because stmt is changed, unless it was comment or a nested thingy + if !strings.Contains("comment message enum", f.lastStmt) { + io.WriteString(f.w, "\n") + } + } + f.indent(0) + f.lastStmt = stmt +} + +func (f *formatter) indent(diff int) { + f.indentLevel += diff + for i := 0; i < f.indentLevel; i++ { + io.WriteString(f.w, f.indentSeparator) + } +} + +func paddedTo(word string, length int) string { + if len(word) >= length { + return word + } + return word + strings.Repeat(" ", length-len(word)) +} diff --git a/cmd/proto3fmt/main.go b/cmd/proto3fmt/main.go index ce83850..4d86f14 100644 --- a/cmd/proto3fmt/main.go +++ b/cmd/proto3fmt/main.go @@ -7,11 +7,16 @@ import ( "github.com/emicklei/proto3" ) +// go run *.go < example1.proto +// go run *.go < example0.proto func main() { p := proto3.NewParser(os.Stdin) def, err := p.Parse() if err != nil { - log.Fatalln("proto3fmt failed:", p.Line(), err) + log.Fatalln("proto3fmt failed, on line", p.Line(), err) + } + f := &formatter{w: os.Stdout, indentSeparator: " "} + for _, each := range def.Elements { + each.Accept(f) } - log.Printf("%#v", def) } diff --git a/enum.go b/enum.go index ab9f7b4..7b76361 100644 --- a/enum.go +++ b/enum.go @@ -7,10 +7,14 @@ import ( // Enum definition consists of a name and an enum body. type Enum struct { - Line int - Name string - Options []*Option - EnumFields []*EnumField + Line int + Name string + Elements []Visitee +} + +// Accept dispatches the call to the visitor. +func (e *Enum) Accept(v Visitor) { + v.VisitEnum(e) } // EnumField is part of the body of an Enum. @@ -20,11 +24,17 @@ type EnumField struct { ValueOption *Option } +// Accept dispatches the call to the visitor. +func (f *EnumField) Accept(v Visitor) { + v.VisitEnumField(f) +} + func (f *EnumField) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { return fmt.Errorf("found %q, expected identifier", lit) } + f.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tEQUALS { return fmt.Errorf("found %q, expected =", lit) @@ -64,16 +74,18 @@ func (e *Enum) parse(p *Parser) error { for { tok, lit = p.scanIgnoreWhitespace() switch tok { - case tRIGHTCURLY: - goto done - case tSEMICOLON: + case tCOMMENT: + e.Elements = append(e.Elements, p.newComment(lit)) case tOPTION: v := new(Option) err := v.parse(p) if err != nil { return err } - e.Options = append(e.Options, v) + e.Elements = append(e.Elements, v) + case tRIGHTCURLY: + goto done + case tSEMICOLON: default: p.unscan() f := new(EnumField) @@ -81,7 +93,7 @@ func (e *Enum) parse(p *Parser) error { if err != nil { return err } - e.EnumFields = append(e.EnumFields, f) + e.Elements = append(e.Elements, f) } } done: diff --git a/enum_test.go b/enum_test.go index c37d6dd..d2f9cd3 100644 --- a/enum_test.go +++ b/enum_test.go @@ -36,13 +36,14 @@ enum EnumAllowingAlias { if err != nil { t.Fatal(err) } - if got, want := len(pr.Enums), 1; got != want { + if got, want := len(collect(pr).Enums()), 1; got != want { t.Errorf("got [%v] want [%v]", got, want) } - if got, want := len(pr.Enums[0].EnumFields), 3; got != want { + if got, want := len(collect(pr).Enums()[0].Elements), 4; got != want { t.Errorf("got [%v] want [%v]", got, want) } - if got, want := pr.Enums[0].EnumFields[0].Integer, 0; got != want { + e := collect(pr).Enums()[0].Elements[1].(*EnumField) + if got, want := e.Integer, 0; got != want { t.Errorf("got [%v] want [%v]", got, want) } } diff --git a/field.go b/field.go index 75a5c1a..7e14e73 100644 --- a/field.go +++ b/field.go @@ -3,7 +3,6 @@ package proto3 import ( "fmt" "strconv" - "strings" ) // Field is a message field. @@ -11,8 +10,13 @@ type Field struct { Name string Type string Repeated bool - Messages []*Message Sequence int + Messages []*Message +} + +// Accept dispatches the call to the visitor. +func (f *Field) Accept(v Visitor) { + v.VisitField(f) } func (f *Field) parse(p *Parser) error { @@ -20,14 +24,9 @@ func (f *Field) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() switch tok { case tIDENT: - // normal type? - if strings.Contains(typeTokens, lit) { - f.Type = lit - return parseNormalField(f, p) - } - //if tok == ONEOF {} - //if tok == ONEOFFIELD {} - case tMESSAGE: + f.Type = lit + return parseNormalField(f, p) + case tMESSAGE: // TODO here? m := new(Message) err := m.parse(p) if err != nil { @@ -37,6 +36,14 @@ func (f *Field) parse(p *Parser) error { case tREPEATED: f.Repeated = true return f.parse(p) + case tMAP: + tok, lit := p.scanIgnoreWhitespace() + if tLESS != tok { + return fmt.Errorf("found %q, expected <", lit) + } + kvtypes := p.s.scanUntil('>') + f.Type = fmt.Sprintf("map<%s>", kvtypes) + return parseNormalField(f, p) default: goto done } diff --git a/import.go b/import.go index 5895a81..d271955 100644 --- a/import.go +++ b/import.go @@ -9,6 +9,11 @@ type Import struct { Kind string // weak, public, } +// Accept dispatches the call to the visitor. +func (i *Import) Accept(v Visitor) { + v.VisitImport(i) +} + func (i *Import) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() i.Line = p.s.line diff --git a/message.go b/message.go index 1b519ae..90dc155 100644 --- a/message.go +++ b/message.go @@ -4,12 +4,13 @@ import "fmt" // Message consists of a message name and a message body. type Message struct { - Line int - Comments []*Comment + Name string + Elements []Visitee +} - Name string - Fields []*Field - Enums []*Enum +// Accept dispatches the call to the visitor. +func (m *Message) Accept(v Visitor) { + v.VisitMessage(m) } func (m *Message) parse(p *Parser) error { @@ -26,7 +27,7 @@ func (m *Message) parse(p *Parser) error { tok, lit = p.scanIgnoreWhitespace() switch tok { case tCOMMENT: - m.Comments = append(m.Comments, p.newComment(lit)) + m.Elements = append(m.Elements, p.newComment(lit)) case tRIGHTCURLY: goto done case tSEMICOLON: @@ -36,15 +37,24 @@ func (m *Message) parse(p *Parser) error { if err != nil { return err } - m.Enums = append(m.Enums, e) + m.Elements = append(m.Elements, e) + case tONEOF: + o := new(Oneof) + err := o.parse(p) + if err != nil { + return err + } + m.Elements = append(m.Elements, o) default: + // tFIELD + // tMAP p.unscan() f := new(Field) err := f.parse(p) if err != nil { return err } - m.Fields = append(m.Fields, f) + m.Elements = append(m.Elements, f) } } done: diff --git a/message_test.go b/message_test.go index 7ae5c86..c3bc7e2 100644 --- a/message_test.go +++ b/message_test.go @@ -37,10 +37,25 @@ func TestMessageWithFieldsAndComments(t *testing.T) { if got, want := m.Name, "AccountOut"; got != want { t.Errorf("got [%v] want [%v]", got, want) } - if got, want := len(m.Fields), 2; got != want { + if got, want := len(m.Elements), 4; got != want { t.Errorf("got [%v] want [%v]", got, want) } - if got, want := len(m.Comments), 2; got != want { - t.Errorf("got [%v] want [%v]", got, want) +} + +func TestOneOf(t *testing.T) { + proto := ` + message Sample { + oneof foo { + string name = 4; + SubMessage sub_message = 9; + } + } +` + p := NewParser(strings.NewReader(proto)) + p.scanIgnoreWhitespace() // consume first token + m := new(Message) + err := m.parse(p) + if err != nil { + t.Fatal(err) } } diff --git a/oneof.go b/oneof.go new file mode 100644 index 0000000..1eecfad --- /dev/null +++ b/oneof.go @@ -0,0 +1,84 @@ +package proto3 + +import ( + "fmt" + "strconv" +) + +// Oneof is a field alternate. +type Oneof struct { + Name string + Elements []Visitee +} + +func (o *Oneof) parse(p *Parser) error { + tok, lit := p.scanIgnoreWhitespace() + if tok != tIDENT { + return fmt.Errorf("found %q, expected identifier", lit) + } + o.Name = lit + tok, lit = p.scanIgnoreWhitespace() + if tok != tLEFTCURLY { + return fmt.Errorf("found %q, expected {", lit) + } + for { + tok, lit := p.scanIgnoreWhitespace() + if tRIGHTCURLY == tok { + break + } + if tIDENT == tok { + f := new(OneOfField) + f.Type = lit + err := f.parse(p) + if err != nil { + return err + } + o.Elements = append(o.Elements, f) + } + } + return nil +} + +// Accept dispatches the call to the visitor. +func (o *Oneof) Accept(v Visitor) { + v.VisitOneof(o) +} + +// OneOfField is part of Oneof. +type OneOfField struct { + Name string + Type string + Sequence int + Options []*Option +} + +// Accept dispatches the call to the visitor. +func (o *OneOfField) Accept(v Visitor) { + v.VisitOneofField(o) +} + +func (o *OneOfField) parse(p *Parser) error { + tok, lit := p.scanIgnoreWhitespace() + if tok != tIDENT { + return fmt.Errorf("found %q, expected identifier", lit) + } + o.Name = lit + tok, lit = p.scanIgnoreWhitespace() + if tok != tEQUALS { + return fmt.Errorf("found %q, expected =", lit) + } + _, lit = p.scanIgnoreWhitespace() + i, err := strconv.Atoi(lit) + if err != nil { + return fmt.Errorf("found %q, expected sequence number", lit) + } + o.Sequence = i + tok, lit = p.scanIgnoreWhitespace() + if tLEFTSQUARE == tok { + // TODO + p.s.scanUntil(']') + } else { + p.unscan() + } + return nil +} diff --git a/oneof_test.go b/oneof_test.go new file mode 100644 index 0000000..7349752 --- /dev/null +++ b/oneof_test.go @@ -0,0 +1,36 @@ +package proto3 + +import ( + "strings" + "testing" +) + +func TestOneof(t *testing.T) { + proto := `oneof foo { + string name = 4; + SubMessage sub_message = 9 [options=none]; +}` + p := NewParser(strings.NewReader(proto)) + p.scanIgnoreWhitespace() // consume first token + o := new(Oneof) + err := o.parse(p) + if err != nil { + t.Fatal(err) + } + if got, want := o.Name, "foo"; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } + if got, want := len(o.Elements), 2; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } + second := o.Elements[1].(*OneOfField) + if got, want := second.Name, "sub_message"; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } + if got, want := second.Type, "SubMessage"; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } + if got, want := second.Sequence, 9; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } +} diff --git a/option.go b/option.go index 2e986a2..6aaf0c0 100644 --- a/option.go +++ b/option.go @@ -10,8 +10,8 @@ type Option struct { Boolean bool } -// accept dispatches the call to the visitor. -func (o *Option) accept(v Visitor) { +// Accept dispatches the call to the visitor. +func (o *Option) Accept(v Visitor) { v.VisitOption(o) } diff --git a/package.go b/package.go index 9c887df..d9ea315 100644 --- a/package.go +++ b/package.go @@ -1,15 +1,22 @@ package proto3 +import "fmt" + // Package specifies the namespace for all proto elements. type Package struct { Name string } -// accept dispatches the call to the visitor. -func (p *Package) accept(v Visitor) { +// Accept dispatches the call to the visitor. +func (p *Package) Accept(v Visitor) { v.VisitPackage(p) } func (p *Package) parse(pr *Parser) error { + tok, lit := pr.scanIgnoreWhitespace() + if tIDENT != tok { + return fmt.Errorf("found %q, expected identifier", lit) + } + p.Name = lit return nil } diff --git a/parser_test.go b/parser_test.go index 3c5ba39..b76bb31 100644 --- a/parser_test.go +++ b/parser_test.go @@ -16,7 +16,7 @@ func TestParseComment(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := len(pr.Comments), 2; got != want { + if got, want := len(collect(pr).Comments()), 2; got != want { t.Errorf("got [%v] want [%v]", got, want) } } diff --git a/proto.go b/proto.go index 17b9e24..3da5ae2 100644 --- a/proto.go +++ b/proto.go @@ -4,12 +4,7 @@ import "strings" // Proto represents a .proto definition type Proto struct { - Syntax *Syntax - Imports []*Import - Enums []*Enum - Services []*Service - Messages []*Message - Comments []*Comment + Elements []Visitee } // Comment holds a message and line number. @@ -18,6 +13,11 @@ type Comment struct { Message string } +// Accept dispatches the call to the visitor. +func (c *Comment) Accept(v Visitor) { + v.VisitComment(c) +} + // IsMultiline returns whether its message has one or more lineends. func (c Comment) IsMultiline() bool { return strings.Contains(c.Message, "\n") @@ -28,38 +28,45 @@ func (proto *Proto) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() switch tok { case tCOMMENT: - proto.Comments = append(proto.Comments, p.newComment(lit)) + proto.Elements = append(proto.Elements, p.newComment(lit)) case tSYNTAX: s := new(Syntax) if err := s.parse(p); err != nil { return err } - proto.Syntax = s + proto.Elements = append(proto.Elements, s) case tIMPORT: im := new(Import) if err := im.parse(p); err != nil { return err } - proto.Imports = append(proto.Imports, im) + proto.Elements = append(proto.Elements, im) case tENUM: enum := new(Enum) if err := enum.parse(p); err != nil { return err } - proto.Enums = append(proto.Enums, enum) + proto.Elements = append(proto.Elements, enum) case tSERVICE: service := new(Service) err := service.parse(p) if err != nil { return err } - proto.Services = append(proto.Services, service) + proto.Elements = append(proto.Elements, service) + case tPACKAGE: + pkg := new(Package) + err := pkg.parse(p) + if err != nil { + return err + } + proto.Elements = append(proto.Elements, pkg) case tMESSAGE: msg := new(Message) if err := msg.parse(p); err != nil { return err } - proto.Messages = append(proto.Messages, msg) + proto.Elements = append(proto.Elements, msg) case tEOF: return nil } diff --git a/scanner.go b/scanner.go index 364cd63..bf70688 100644 --- a/scanner.go +++ b/scanner.go @@ -61,6 +61,8 @@ func (s *scanner) scan() (tok token, lit string) { return tRIGHTSQUARE, string(ch) case '/': return tCOMMENT, s.scanComment() + case '<': + return tLESS, string(ch) } return tILLEGAL, string(ch) } @@ -158,6 +160,10 @@ func (s *scanner) scanIdent() (tok token, lit string) { return tWEAK, buf.String() case "public": return tPUBLIC, buf.String() + case "map": + return tMAP, buf.String() + case "oneof": + return tONEOF, buf.String() } // Otherwise return as a regular identifier. diff --git a/service.go b/service.go index 6608756..46dfa1e 100644 --- a/service.go +++ b/service.go @@ -9,8 +9,8 @@ type Service struct { RPCalls []*RPcall } -// accept dispatches the call to the visitor. -func (s *Service) accept(v Visitor) { +// Accept dispatches the call to the visitor. +func (s *Service) Accept(v Visitor) { v.VisitService(s) } diff --git a/service_test.go b/service_test.go index f5b9d53..a1b7863 100644 --- a/service_test.go +++ b/service_test.go @@ -11,7 +11,7 @@ func TestService(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := stmt.Services[0].Name, "AccountService"; got != want { + if got, want := collect(stmt).Services()[0].Name, "AccountService"; got != want { t.Errorf("got [%v] want [%v]", got, want) } } @@ -25,7 +25,7 @@ func TestServiceWithRPCs(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := len(stmt.Services[0].RPCalls), 2; got != want { + if got, want := len(collect(stmt).Services()[0].RPCalls), 2; got != want { t.Errorf("got [%v] want [%v]", got, want) } } diff --git a/syntax.go b/syntax.go index 0064487..e7fb987 100644 --- a/syntax.go +++ b/syntax.go @@ -7,7 +7,8 @@ type Syntax struct { Value string } -func (s *Syntax) accept(v Visitor) { +// Accept dispatches the call to the visitor. +func (s *Syntax) Accept(v Visitor) { v.VisitSyntax(s) } diff --git a/syntax_test.go b/syntax_test.go index e7eacf4..cd1cadf 100644 --- a/syntax_test.go +++ b/syntax_test.go @@ -32,14 +32,15 @@ func TestCommentAroundSyntax(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := len(r.Comments), 4; got != want { + comments := collect(r).Comments() + if got, want := len(comments), 4; got != want { t.Errorf("got [%v] want [%v]", got, want) } for i := 1; i <= 4; i++ { - if got, want := r.Comments[i-1].Message, " comment"+strconv.Itoa(i); got != want { + if got, want := comments[i-1].Message, " comment"+strconv.Itoa(i); got != want { t.Errorf("got [%v] want [%v]", got, want) } - if got, want := r.Comments[i-1].Line, i+1; got != want { + if got, want := comments[i-1].Line, i+1; got != want { t.Errorf("got [%v] want [%v]", got, want) } } diff --git a/token.go b/token.go index 0a37472..eb9d692 100644 --- a/token.go +++ b/token.go @@ -25,6 +25,7 @@ const ( tLEFTSQUARE // [ tRIGHTSQUARE // ] tCOMMENT // / + tLESS // < // Keywords tSYNTAX @@ -41,7 +42,6 @@ const ( // special fields tONEOF - tONEOFFIELD tMAP tRESERVED tENUM diff --git a/visitor.go b/visitor.go index 1259f09..cd99423 100644 --- a/visitor.go +++ b/visitor.go @@ -7,6 +7,13 @@ type Visitor interface { VisitSyntax(s *Syntax) VisitPackage(p *Package) VisitOption(o *Option) + VisitImport(i *Import) + VisitField(i *Field) + VisitEnumField(i *EnumField) + VisitEnum(e *Enum) + VisitComment(e *Comment) + VisitOneof(o *Oneof) + VisitOneofField(o *OneOfField) } // Visitee is implemented by all Proto elements. diff --git a/visitor_test.go b/visitor_test.go new file mode 100644 index 0000000..a26e069 --- /dev/null +++ b/visitor_test.go @@ -0,0 +1,45 @@ +package proto3 + +type collector struct { + proto *Proto +} + +func collect(p *Proto) collector { + return collector{p} +} + +func (c collector) Comments() (list []*Comment) { + for _, each := range c.proto.Elements { + if c, ok := each.(*Comment); ok { + list = append(list, c) + } + } + return +} + +func (c collector) Enums() (list []*Enum) { + for _, each := range c.proto.Elements { + if c, ok := each.(*Enum); ok { + list = append(list, c) + } + } + return +} + +func (c collector) Messages() (list []*Message) { + for _, each := range c.proto.Elements { + if c, ok := each.(*Message); ok { + list = append(list, c) + } + } + return +} + +func (c collector) Services() (list []*Service) { + for _, each := range c.proto.Elements { + if c, ok := each.(*Service); ok { + list = append(list, c) + } + } + return +}