diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go
index 4168c667387..b348a664645 100644
--- a/go/vt/sqlparser/comments.go
+++ b/go/vt/sqlparser/comments.go
@@ -151,7 +151,17 @@ func hasCommentPrefix(sql string) bool {
 func ExtractMysqlComment(sql string) (version string, innerSQL string) {
 	sql = sql[3 : len(sql)-2]
 
-	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool { return !unicode.IsDigit(c) })
+	// The optional version prefix is at most 5 digits (e.g. 50708): stop at
+	// the first non-digit or after the 5th digit, whichever comes first.
+	digitCount := 0
+	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
+		digitCount++
+		return !unicode.IsDigit(c) || digitCount == 6
+	})
+	if endOfVersionIndex < 0 {
+		// Comment body is empty or all digits (e.g. /*!*/ or /*!50708*/):
+		// without this guard, sql[0:endOfVersionIndex] below would panic.
+		endOfVersionIndex = len(sql)
+	}
 	version = sql[0:endOfVersionIndex]
 	innerSQL = strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
diff --git a/go/vt/sqlparser/comments_test.go b/go/vt/sqlparser/comments_test.go
index fdee23f92bf..aa9ceffc858 100644
--- a/go/vt/sqlparser/comments_test.go
+++ b/go/vt/sqlparser/comments_test.go
@@ -189,6 +189,18 @@ func TestExtractMysqlComment(t *testing.T) {
 		input:      "/*!50708 SET max_execution_time=5000*/",
 		outSQL:     "SET max_execution_time=5000",
 		outVersion: "50708",
+	}, {
+		input:      "/*!50708* from*/",
+		outSQL:     "* from",
+		outVersion: "50708",
+	}, {
+		input:      "/*!50708*/",
+		outSQL:     "",
+		outVersion: "50708",
+	}, {
+		input:      "/*!*/",
+		outSQL:     "",
+		outVersion: "",
 	}, {
 		input:      "/*! SET max_execution_time=5000*/",
 		outSQL:     "SET max_execution_time=5000",
diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go
index 434bbea555b..458ecac2726 100644
--- a/go/vt/sqlparser/parse_test.go
+++ b/go/vt/sqlparser/parse_test.go
@@ -549,7 +549,20 @@ var (
 	}, {
 		input: "select /* string in case statement */ if(max(case a when 'foo' then 1 else 0 end) = 1, 'foo', 'bar') as foobar from t",
 	}, {
-		input: "select /* dual */ 1 from dual",
+		input:  "/*!show databases*/",
+		output: "show databases",
+	}, {
+		input:  "select /*!40101 * from*/ t",
+		output: "select * from t",
+	}, {
+		input:  "select /*! \n* from*/ t",
+		output: "select * from t",
+	}, {
+		input:  "select /*!* from*/ t",
+		output: "select * from t",
+	}, {
+		input:  "select /*!401011 from*/ t",
+		output: "select 1 from t",
 	}, {
 		input: "select /* dual */ 1 from dual",
 	}, {
diff --git a/go/vt/sqlparser/token.go b/go/vt/sqlparser/token.go
index 77ea928b4b6..40165c9a77b 100644
--- a/go/vt/sqlparser/token.go
+++ b/go/vt/sqlparser/token.go
@@ -34,18 +34,19 @@ const (
 // Tokenizer is the struct used to generate SQL
 // tokens for the parser.
 type Tokenizer struct {
-	InStream      io.Reader
-	AllowComments bool
-	ForceEOF      bool
-	lastChar      uint16
-	Position      int
-	lastToken     []byte
-	LastError     error
-	posVarIndex   int
-	ParseTree     Statement
-	partialDDL    *DDL
-	nesting       int
-	multi         bool
+	InStream       io.Reader
+	AllowComments  bool
+	ForceEOF       bool
+	lastChar       uint16
+	Position       int
+	lastToken      []byte
+	LastError      error
+	posVarIndex    int
+	ParseTree      Statement
+	partialDDL     *DDL
+	nesting        int
+	multi          bool
+	specialComment *Tokenizer
 
 	buf    []byte
 	bufPos int
@@ -427,6 +428,18 @@ func (tkn *Tokenizer) Error(err string) {
 // Scan scans the tokenizer for the next token and returns
 // the token type and an optional value.
 func (tkn *Tokenizer) Scan() (int, []byte) {
+	if tkn.specialComment != nil {
+		// A /*! ... */ MySQL-specific comment is being replayed: drain its
+		// sub-tokenizer before returning to the main stream.
+		specialComment := tkn.specialComment
+		tok, val := specialComment.Scan()
+		if tok != 0 {
+			// Forward the comment's token as this tokenizer's own.
+			return tok, val
+		}
+		// Sub-tokenizer exhausted: leave special-comment mode.
+		tkn.specialComment = nil
+	}
 	if tkn.lastChar == 0 {
 		tkn.next()
 	}
@@ -495,7 +508,12 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
 			return tkn.scanCommentType1("//")
 		case '*':
 			tkn.next()
-			return tkn.scanCommentType2()
+			switch tkn.lastChar {
+			case '!':
+				return tkn.scanMySQLSpecificComment()
+			default:
+				return tkn.scanCommentType2()
+			}
 		default:
 			return int(ch), nil
 		}
@@ -818,6 +836,31 @@ func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
 	return COMMENT, buffer.Bytes()
 }
 
+// scanMySQLSpecificComment scans a MySQL-specific comment of the form
+// /*! ... */ and, instead of discarding it, re-tokenizes its inner SQL
+// through a nested Tokenizer so the contents participate in parsing.
+func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
+	buffer := &bytes2.Buffer{}
+	buffer.WriteString("/*!")
+	tkn.next()
+	for {
+		if tkn.lastChar == '*' {
+			tkn.consumeNext(buffer)
+			if tkn.lastChar == '/' {
+				tkn.consumeNext(buffer)
+				break
+			}
+			continue
+		}
+		if tkn.lastChar == eofChar {
+			return LEX_ERROR, buffer.Bytes()
+		}
+		tkn.consumeNext(buffer)
+	}
+	_, sql := ExtractMysqlComment(buffer.String())
+	tkn.specialComment = NewStringTokenizer(sql)
+	return tkn.Scan()
+}
+
 func (tkn *Tokenizer) consumeNext(buffer *bytes2.Buffer) {
 	if tkn.lastChar == eofChar {
 		// This should never happen.
@@ -853,6 +896,7 @@ func (tkn *Tokenizer) next() {
 func (tkn *Tokenizer) reset() {
 	tkn.ParseTree = nil
 	tkn.partialDDL = nil
+	tkn.specialComment = nil
 	tkn.posVarIndex = 0
 	tkn.nesting = 0
 	tkn.ForceEOF = false