diff --git a/aligned.go b/aligned.go index 70fede1..3b4e10a 100644 --- a/aligned.go +++ b/aligned.go @@ -12,6 +12,7 @@ var ( alignedShortEquals = leftAligned("=") alignedSpace = leftAligned(" ") alignedComma = leftAligned(",") + alignedEmpty = leftAligned("") ) func leftAligned(src string) aligned { return aligned{src, true} } diff --git a/cmd/protofmt/main.go b/cmd/protofmt/main.go index 9b0852b..340e67a 100644 --- a/cmd/protofmt/main.go +++ b/cmd/protofmt/main.go @@ -23,4 +23,5 @@ func main() { log.Fatalln("protofmt failed", err) } proto.NewFormatter(os.Stdout, " ").Format(def) + //spew.Dump(def) } diff --git a/cmd/protofmt/unformatted.proto b/cmd/protofmt/unformatted.proto index 9b34dbc..089ceb9 100644 --- a/cmd/protofmt/unformatted.proto +++ b/cmd/protofmt/unformatted.proto @@ -61,4 +61,11 @@ service SearchService { // comment enum Enum {} service Service {} message Message {} -oneof Oneof {} \ No newline at end of file + +// context aware +enum enum { + enum = 0; +} +message message { + message message = 1; +} \ No newline at end of file diff --git a/enum.go b/enum.go index 0e46571..28df780 100644 --- a/enum.go +++ b/enum.go @@ -1,9 +1,6 @@ package proto -import ( - "fmt" - "strconv" -) +import "strconv" // Enum definition consists of a name and an enum body. 
type Enum struct { @@ -41,16 +38,16 @@ func (f EnumField) columns() (cols []aligned) { func (f *EnumField) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return fmt.Errorf("found %q, expected identifier", lit) + return p.unexpected(lit, "enum field identifier", f) } f.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tEQUALS { - return p.unexpected(lit, "=") + return p.unexpected(lit, "enum field =", f) } i, err := p.s.scanInteger() if err != nil { - return fmt.Errorf("found %q, expected integer", err) + return p.unexpected(lit, "enum field integer", f) } f.Integer = i tok, lit = p.scanIgnoreWhitespace() @@ -64,21 +61,21 @@ func (f *EnumField) parse(p *Parser) error { f.ValueOption = o tok, lit = p.scanIgnoreWhitespace() if tok != tRIGHTSQUARE { - return fmt.Errorf("found %q, expected ]", lit) + return p.unexpected(lit, "option closing ]", f) } } return nil } func (e *Enum) parse(p *Parser) error { - tok, lit := p.scanIgnoreWhitespace() + tok, lit := p.s.scanIdent() if tok != tIDENT { - return fmt.Errorf("found %q, expected identifier", lit) + return p.unexpected(lit, "enum identifier", e) } e.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tLEFTCURLY { - return fmt.Errorf("found %q, expected {", lit) + return p.unexpected(lit, "enum opening {", e) } for { tok, lit = p.scanIgnoreWhitespace() @@ -107,7 +104,7 @@ func (e *Enum) parse(p *Parser) error { } done: if tok != tRIGHTCURLY { - return fmt.Errorf("found %q, expected }", lit) + return p.unexpected(lit, "enum closing }", e) } return nil } diff --git a/field.go b/field.go index 24ee492..d58de5d 100644 --- a/field.go +++ b/field.go @@ -73,16 +73,16 @@ done: func parseFieldAfterType(f *Field, p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "field identifier", f) } f.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tEQUALS { - return p.unexpected(lit, "=") + 
return p.unexpected(lit, "field =", f) } i, err := p.s.scanInteger() if err != nil { - return p.unexpected(lit, "sequence number") + return p.unexpected(lit, "field sequence number", f) } f.Sequence = i // see if there are options @@ -106,7 +106,7 @@ func parseFieldAfterType(f *Field, p *Parser) error { break } if tCOMMA != tok { - return p.unexpected(lit, ",") + return p.unexpected(lit, "option ,", o) } } return nil @@ -132,25 +132,25 @@ func (f *MapField) Accept(v Visitor) { func (f *MapField) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tLESS != tok { - return p.unexpected(lit, "<") + return p.unexpected(lit, "map keyType <", f) } tok, lit = p.scanIgnoreWhitespace() if tIDENT != tok { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "map identifier", f) } f.KeyType = lit tok, lit = p.scanIgnoreWhitespace() if tCOMMA != tok { - return p.unexpected(lit, ",") + return p.unexpected(lit, "map type separator ,", f) } tok, lit = p.scanIgnoreWhitespace() if tIDENT != tok { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "map valueType identifier", f) } f.Type = lit tok, lit = p.scanIgnoreWhitespace() if tGREATER != tok { - return p.unexpected(lit, ">") + return p.unexpected(lit, "map valueType >", f) } return parseFieldAfterType(f.Field, p) } diff --git a/formatter_utils.go b/formatter_utils.go index 5fe1ffc..290d4b3 100644 --- a/formatter_utils.go +++ b/formatter_utils.go @@ -90,22 +90,30 @@ func (f *Formatter) printAsGroups(list []Visitee) { group := []columnsPrintable{} lastGroupName := "" for i := 0; i < len(list); i++ { - groupName := nameOfVisitee(list[i]) - printable, isColumnsPrintable := list[i].(columnsPrintable) + each := list[i] + groupName := nameOfVisitee(each) + printable, isColumnsPrintable := each.(columnsPrintable) if isColumnsPrintable { - if lastGroupName == groupName { - // collect in group - group = append(group, printable) - } else { + if lastGroupName != groupName { // print current group 
+ if len(group) > 0 { + f.printListOfColumns(group) + lastGroupName = groupName + // begin new group + group = []columnsPrintable{} + } + } + group = append(group, printable) + } else { + // not printable in group + // print current group + if len(group) > 0 { f.printListOfColumns(group) lastGroupName = groupName // begin new group - group = []columnsPrintable{printable} + group = []columnsPrintable{} } - } else { - // not printable in group - list[i].Accept(f) + each.Accept(f) } } // print last group diff --git a/import.go b/import.go index d5c60f6..a98c2c7 100644 --- a/import.go +++ b/import.go @@ -22,8 +22,10 @@ func (i *Import) parse(p *Parser) error { return i.parse(p) case tQUOTE: i.Filename = p.s.scanUntil('"') + case tSINGLEQUOTE: + i.Filename = p.s.scanUntil('\'') default: - return p.unexpected(lit, "weak|public|quoted identifier") + return p.unexpected(lit, "import classifier weak|public|quoted", i) } return nil } diff --git a/message.go b/message.go index 4db62d6..80c6fce 100644 --- a/message.go +++ b/message.go @@ -14,12 +14,12 @@ func (m *Message) Accept(v Visitor) { func (m *Message) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "message identifier", m) } m.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tLEFTCURLY { - return p.unexpected(lit, "{") + return p.unexpected(lit, "message opening {", m) } for { tok, lit = p.scanIgnoreWhitespace() @@ -78,7 +78,7 @@ func (m *Message) parse(p *Parser) error { } done: if tok != tRIGHTCURLY { - return p.unexpected(lit, "}") + return p.unexpected(lit, "message closing }", m) } return nil } diff --git a/oneof.go b/oneof.go index 53a0d9a..f41cb30 100644 --- a/oneof.go +++ b/oneof.go @@ -60,17 +60,17 @@ func (o *OneOfField) Accept(v Visitor) { func (o *OneOfField) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "identifier") + return 
p.unexpected(lit, "oneof field identifier", o) } o.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tEQUALS { - return p.unexpected(lit, "=") + return p.unexpected(lit, "oneof field =", o) } _, lit = p.scanIgnoreWhitespace() i, err := strconv.Atoi(lit) if err != nil { - return p.unexpected(lit, "sequence number") + return p.unexpected(lit, "oneof sequence number", o) } o.Sequence = i tok, lit = p.scanIgnoreWhitespace() diff --git a/option.go b/option.go index 6d8f5c0..f4abbba 100644 --- a/option.go +++ b/option.go @@ -44,28 +44,28 @@ func (o *Option) parse(p *Parser) error { case tLEFTPAREN: tok, lit = p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "option identifier", o) } o.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tRIGHTPAREN { - return p.unexpected(lit, ")") + return p.unexpected(lit, "option closing )", o) } default: - return p.unexpected(lit, "identifier or (") + return p.unexpected(lit, "option identifier or (", o) } tok, lit = p.scanIgnoreWhitespace() if tok == tDOT { // extend identifier tok, lit = p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "postfix identifier") + return p.unexpected(lit, "option postfix identifier", o) } o.Name = fmt.Sprintf("%s.%s", o.Name, lit) tok, lit = p.scanIgnoreWhitespace() } if tok != tEQUALS { - return p.unexpected(lit, "=") + return p.unexpected(lit, "option constant =", o) } l := new(Literal) if err := l.parse(p); err != nil { @@ -95,7 +95,7 @@ func (l *Literal) parse(p *Parser) error { if tok == tQUOTE { ident := p.s.scanUntil('"') if len(ident) == 0 { - return p.unexpected(lit, "quoted string") + return p.unexpected(lit, "literal quoted string", l) } l.Source, l.IsString = ident, true return nil @@ -104,7 +104,7 @@ func (l *Literal) parse(p *Parser) error { if tok == tSINGLEQUOTE { ident := p.s.scanUntil('\'') if len(ident) == 0 { - return p.unexpected(lit, "single quoted string") + return p.unexpected(lit, 
"literal single quoted string", l) } l.Source, l.IsString = ident, true return nil diff --git a/parser.go b/parser.go index 69f3d77..fd31485 100644 --- a/parser.go +++ b/parser.go @@ -55,6 +55,12 @@ func (p *Parser) scanIgnoreWhitespace() (tok token, lit string) { return } +// scanIdent scans all whitespaces and scans the next non-whitespace identifier (not a keyword). +func (p *Parser) scanIdent() (tok token, lit string) { + p.s.skipWhitespace() + return p.s.scanIdent() +} + // unscan pushes the previously read token back onto the buffer. func (p *Parser) unscan() { p.buf.n = 1 } @@ -63,11 +69,11 @@ func (p *Parser) newComment(lit string) *Comment { return &Comment{Message: lit} } -func (p *Parser) unexpected(found, expected string) error { +func (p *Parser) unexpected(found, expected string, obj interface{}) error { debug := "" if p.debug { _, file, line, _ := runtime.Caller(1) - debug = fmt.Sprintf(" at %s:%d", file, line) + debug = fmt.Sprintf(" at %s:%d (with %#v)", file, line, obj) } return fmt.Errorf("found %q on line %d, expected %s%s", found, p.s.line, expected, debug) } @@ -79,16 +85,16 @@ func (p *Parser) scanStringLiteral() (string, error) { if tok == tQUOTE { s := p.s.scanUntil('"') if len(s) == 0 { - return "", p.unexpected(lit, "quoted string") + return "", p.unexpected(lit, "quoted string", nil) } return s, nil } if tok == tSINGLEQUOTE { s := p.s.scanUntil('\'') if len(s) == 0 { - return "", p.unexpected(lit, "single quoted string") + return "", p.unexpected(lit, "single quoted string", nil) } return s, nil } - return "", p.unexpected(lit, "single or double quoted string") + return "", p.unexpected(lit, "single or double quoted string", nil) } diff --git a/parser_test.go b/parser_test.go index e5b1b51..dea458a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -26,3 +26,14 @@ func newParserOn(def string) *Parser { p.debug = true return p } + +func TestScanIdent(t *testing.T) { + p := NewParser(strings.NewReader(" message ")) + tok, lit := 
p.scanIdent() + if got, want := tok, tIDENT; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } + if got, want := lit, "message"; got != want { + t.Errorf("got [%v] want [%v]", got, want) + } +} diff --git a/proto.go b/proto.go index 9984da5..555a191 100644 --- a/proto.go +++ b/proto.go @@ -1,9 +1,6 @@ package proto -import ( - "log" - "strings" -) +import "strings" // Proto represents a .proto definition type Proto struct { @@ -76,11 +73,10 @@ func (proto *Proto) parse(p *Parser) error { } proto.Elements = append(proto.Elements, msg) case tSEMICOLON: - default: - if p.debug { - log.Println("unhandled (1=EOF)", lit, tok) - } + case tEOF: goto done + default: + return p.unexpected(lit, "comment|option|import|syntax|enum|service|package|message", p) } } done: diff --git a/scanner.go b/scanner.go index 27c1d96..f55fda8 100644 --- a/scanner.go +++ b/scanner.go @@ -27,14 +27,13 @@ func (s *scanner) scan() (tok token, lit string) { // If we see whitespace then consume all contiguous whitespace. // If we see a letter then consume as an ident or reserved word. - // If we see a digit then consume as a number. // If we see a slash then consume all as a comment (can be multiline) if isWhitespace(ch) { s.unread(ch) return s.scanWhitespace() } else if isLetter(ch) { s.unread(ch) - return s.scanIdent() + return s.scanKeyword() } // Otherwise read the individual character. @@ -97,6 +96,19 @@ func (s *scanner) scanWhitespace() (tok token, lit string) { return tWS, buf.String() } +// skipWhitespace consumes all contiguous whitespace. +func (s *scanner) skipWhitespace() { + // Non-whitespace characters and EOF will cause the loop to exit. 
+ for { + if ch := s.read(); ch == eof { + break + } else if !isWhitespace(ch) { + s.unread(ch) + break + } + } +} + func (s *scanner) scanInteger() (int, error) { var i int if _, err := fmt.Fscanf(s.r, "%d", &i); err != nil { @@ -111,6 +123,27 @@ func (s *scanner) scanIdent() (tok token, lit string) { var buf bytes.Buffer buf.WriteRune(s.read()) + // Read every subsequent ident character into the buffer. + // Non-ident characters and EOF will cause the loop to exit. + for { + if ch := s.read(); ch == eof { + break + } else if !isLetter(ch) && !isDigit(ch) && ch != '_' && ch != '.' { // underscore and dot can be part of identifier + s.unread(ch) + break + } else { + _, _ = buf.WriteRune(ch) + } + } + return tIDENT, buf.String() +} + +// scanKeyword consumes the current rune and all contiguous ident runes. +func (s *scanner) scanKeyword() (tok token, lit string) { + // Create a buffer and read the current character into it. + var buf bytes.Buffer + buf.WriteRune(s.read()) + // Read every subsequent ident character into the buffer. // Non-ident characters and EOF will cause the loop to exit. 
for { diff --git a/service.go b/service.go index c55c285..e7b4578 100644 --- a/service.go +++ b/service.go @@ -15,12 +15,12 @@ func (s *Service) Accept(v Visitor) { func (s *Service) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "identifier") + return p.unexpected(lit, "service identifier", s) } s.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tLEFTCURLY { - return p.unexpected(lit, "{") + return p.unexpected(lit, "service opening {", s) } for { tok, lit = p.scanIgnoreWhitespace() @@ -38,7 +38,7 @@ func (s *Service) parse(p *Parser) error { case tRIGHTCURLY: goto done default: - return p.unexpected(lit, "comment|rpc|;}") + return p.unexpected(lit, "service comment|rpc", s) } } done: @@ -65,21 +65,23 @@ func (r *RPC) columns() (cols []aligned) { leftAligned("rpc "), leftAligned(r.Name), leftAligned(" (")) - stream := "" if r.StreamsRequest { - stream = "stream " + cols = append(cols, leftAligned("stream ")) + } else { + cols = append(cols, alignedEmpty) } cols = append(cols, - leftAligned(stream+r.RequestType), + leftAligned(r.RequestType), leftAligned(") "), leftAligned("returns"), leftAligned(" (")) - stream = "" if r.StreamsReturns { - stream = "stream " + cols = append(cols, leftAligned("stream ")) + } else { + cols = append(cols, alignedEmpty) } cols = append(cols, - leftAligned(stream+r.ReturnsType), + leftAligned(r.ReturnsType), leftAligned(")")) return cols } @@ -88,12 +90,12 @@ func (r *RPC) columns() (cols []aligned) { func (r *RPC) parse(p *Parser) error { tok, lit := p.scanIgnoreWhitespace() if tok != tIDENT { - return p.unexpected(lit, "method") + return p.unexpected(lit, "rpc method", r) } r.Name = lit tok, lit = p.scanIgnoreWhitespace() if tok != tLEFTPAREN { - return p.unexpected(lit, "(") + return p.unexpected(lit, "rpc type opening (", r) } tok, lit = p.scanIgnoreWhitespace() if iSTREAM == lit { @@ -101,20 +103,20 @@ func (r *RPC) parse(p *Parser) error { tok, lit = 
p.scanIgnoreWhitespace() } if tok != tIDENT { - return p.unexpected(lit, "stream | request type") + return p.unexpected(lit, "rpc stream | request type", r) } r.RequestType = lit tok, lit = p.scanIgnoreWhitespace() if tok != tRIGHTPAREN { - return p.unexpected(lit, ")") + return p.unexpected(lit, "rpc type closing )", r) } tok, lit = p.scanIgnoreWhitespace() if tok != tRETURNS { - return p.unexpected(lit, "returns") + return p.unexpected(lit, "rpc returns", r) } tok, lit = p.scanIgnoreWhitespace() if tok != tLEFTPAREN { - return p.unexpected(lit, "(") + return p.unexpected(lit, "rpc type opening (", r) } tok, lit = p.scanIgnoreWhitespace() if iSTREAM == lit { @@ -122,12 +124,12 @@ func (r *RPC) parse(p *Parser) error { tok, lit = p.scanIgnoreWhitespace() } if tok != tIDENT { - return p.unexpected(lit, "stream | returns type") + return p.unexpected(lit, "rpc stream | returns type", r) } r.ReturnsType = lit tok, lit = p.scanIgnoreWhitespace() if tok != tRIGHTPAREN { - return p.unexpected(lit, ")") + return p.unexpected(lit, "rpc type closing )", r) } return nil } diff --git a/syntax.go b/syntax.go index 6eca510..981187a 100644 --- a/syntax.go +++ b/syntax.go @@ -12,17 +12,11 @@ func (s *Syntax) Accept(v Visitor) { func (s *Syntax) parse(p *Parser) error { if tok, lit := p.scanIgnoreWhitespace(); tok != tEQUALS { - return p.unexpected(lit, "=") + return p.unexpected(lit, "syntax =", s) } - if tok, lit := p.scanIgnoreWhitespace(); tok != tQUOTE && tok != tSINGLEQUOTE { - return p.unexpected(lit, "\" or '") - } - tok, lit := p.scanIgnoreWhitespace() - if tok != tIDENT { - return p.unexpected(lit, "proto") - } - if tok, lit := p.scanIgnoreWhitespace(); tok != tQUOTE && tok != tSINGLEQUOTE { - return p.unexpected(lit, "\" or '") + lit, err := p.scanStringLiteral() + if err != nil { + return err } s.Value = lit return nil