diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index 8ae27a85943..5932984558a 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -1854,6 +1854,45 @@ func TestConvert(t *testing.T) { } } +func TestPositionedErr(t *testing.T) { + invalidSQL := []struct { + input string + output PositionedErr + }{{ + input: "select convert('abc' as date) from t", + output: PositionedErr{"syntax error at position 24 near 'as'", 24, []byte("as")}, + }, { + input: "select convert from t", + output: PositionedErr{"syntax error at position 20 near 'from'", 20, []byte("from")}, + }, { + input: "select cast('foo', decimal) from t", + output: PositionedErr{"syntax error at position 19", 19, nil}, + }, { + input: "select convert('abc', datetime(4+9)) from t", + output: PositionedErr{"syntax error at position 34", 34, nil}, + }, { + input: "select convert('abc', decimal(4+9)) from t", + output: PositionedErr{"syntax error at position 33", 33, nil}, + }, { + input: "set transaction isolation level 12345", + output: PositionedErr{"syntax error at position 38 near '12345'", 38, []byte("12345")}, + }, { + input: "select * from a left join b", + output: PositionedErr{"syntax error at position 28", 28, nil}, + }} + + for _, tcase := range invalidSQL { + tkn := NewStringTokenizer(tcase.input) + _, err := ParseNext(tkn) + + if posErr, ok := err.(PositionedErr); !ok { + t.Errorf("%s: %v expected PositionedErr, got (%T) %v", tcase.input, err, err, tcase.output) + } else if posErr.Pos != tcase.output.Pos || !bytes.Equal(posErr.Near, tcase.output.Near) || err.Error() != tcase.output.Err { + t.Errorf("%s: %v, want: %v", tcase.input, err, tcase.output) + } + } +} + func TestSubStr(t *testing.T) { validSQL := []struct { diff --git a/go/vt/sqlparser/token.go b/go/vt/sqlparser/token.go index f7fd4850f30..f80e2c245b9 100644 --- a/go/vt/sqlparser/token.go +++ b/go/vt/sqlparser/token.go @@ -18,7 +18,6 @@ package sqlparser import ( "bytes" - "errors" "fmt" "io" @@ -452,15 +451,23 @@ func (tkn *Tokenizer) Lex(lval *yySymType) int { return typ } +// PositionedErr holds context related to parser errors +type PositionedErr struct { + Err string + Pos int + Near []byte +} + +func (p PositionedErr) Error() string { + if p.Near != nil { + return fmt.Sprintf("%s at position %v near '%s'", p.Err, p.Pos, p.Near) + } + return fmt.Sprintf("%s at position %v", p.Err, p.Pos) +} + // Error is called by go yacc if there's a parsing error. func (tkn *Tokenizer) Error(err string) { - buf := &bytes2.Buffer{} - if tkn.lastToken != nil { - fmt.Fprintf(buf, "%s at position %v near '%s'", err, tkn.Position, tkn.lastToken) - } else { - fmt.Fprintf(buf, "%s at position %v", err, tkn.Position) - } - tkn.LastError = errors.New(buf.String()) + tkn.LastError = PositionedErr{Err: err, Pos: tkn.Position, Near: tkn.lastToken} // Try and re-sync to the next statement tkn.skipStatement()