diff --git a/charset/charset.go b/charset/charset.go index f9fa39263..4ed3eb4d0 100644 --- a/charset/charset.go +++ b/charset/charset.go @@ -107,6 +107,16 @@ func ValidCharsetAndCollation(cs string, co string) bool { return ok } +// GetDefaultCollationLegacy is compatible with the charset support in the old version of the parser. +func GetDefaultCollationLegacy(charset string) (string, error) { + switch strings.ToLower(charset) { + case CharsetUTF8, CharsetUTF8MB4, CharsetASCII, CharsetLatin1, CharsetBin: + return GetDefaultCollation(charset) + default: + return "", errors.Errorf("Unknown charset %s", charset) + } +} + // GetDefaultCollation returns the default collation for charset. func GetDefaultCollation(charset string) (string, error) { cs, err := GetCharsetInfo(charset) diff --git a/charset/encoding.go b/charset/encoding.go new file mode 100644 index 000000000..d31a09759 --- /dev/null +++ b/charset/encoding.go @@ -0,0 +1,137 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package charset + +import ( + "strings" + + "golang.org/x/text/encoding" + "golang.org/x/text/transform" +) + +const ( + encodingBufferSizeDefault = 1024 + encodingBufferSizeRecycleThreshold = 4 * 1024 + + encodingDefault = "utf-8" +) + +type EncodingLabel string + +// Format trims the label and converts it to lowercase. +func Format(label string) EncodingLabel { + return EncodingLabel(strings.ToLower(strings.Trim(label, "\t\n\r\f "))) +} + +// Formatted is used when the label is already trimmed and it is lowercase. 
+func Formatted(label string) EncodingLabel { + return EncodingLabel(label) +} + +// Encoding provides an interface to encode/decode a string with specific encoding. +type Encoding struct { + enc encoding.Encoding + name string + charLength func([]byte) int + buffer []byte +} + +// Enabled indicates whether the non-utf8 encoding is used. +func (e *Encoding) Enabled() bool { + return e.enc != nil && e.charLength != nil +} + +// Name returns the name of the current encoding. +func (e *Encoding) Name() string { + return e.name +} + +// NewEncoding creates a new Encoding. +func NewEncoding(label EncodingLabel) *Encoding { + if len(label) == 0 { + return &Encoding{} + } + e, name := lookup(label) + if e != nil && name != encodingDefault { + return &Encoding{ + enc: e, + name: name, + charLength: FindNextCharacterLength(name), + buffer: make([]byte, encodingBufferSizeDefault), + } + } + return &Encoding{name: name} +} + +// UpdateEncoding updates to a new Encoding without changing the buffer. +func (e *Encoding) UpdateEncoding(label EncodingLabel) { + enc, name := lookup(label) + e.name = name + if enc != nil && name != encodingDefault { + e.enc = enc + } + if len(e.buffer) == 0 { + e.buffer = make([]byte, encodingBufferSizeDefault) + } +} + +// Encode encodes the bytes to a string. +func (e *Encoding) Encode(src []byte) (string, bool) { + return e.transform(e.enc.NewEncoder(), src) +} + +// Decode decodes the bytes to a string. 
+func (e *Encoding) Decode(src []byte) (string, bool) { + return e.transform(e.enc.NewDecoder(), src) +} + +func (e *Encoding) transform(transformer transform.Transformer, src []byte) (string, bool) { + if len(e.buffer) < len(src) { + e.buffer = make([]byte, len(src)*2) + } + var destOffset, srcOffset int + ok := true + for { + nextLen := 4 + if e.charLength != nil { + nextLen = e.charLength(src[srcOffset:]) + } + srcEnd := srcOffset + nextLen + if srcEnd > len(src) { + srcEnd = len(src) + } + nDest, nSrc, err := transformer.Transform(e.buffer[destOffset:], src[srcOffset:srcEnd], false) + destOffset += nDest + srcOffset += nSrc + if err == nil { + if srcOffset >= len(src) { + result := string(e.buffer[:destOffset]) + if len(e.buffer) > encodingBufferSizeRecycleThreshold { + // This prevents Encoding from holding too much memory. + e.buffer = make([]byte, encodingBufferSizeDefault) + } + return result, ok + } + } else if err == transform.ErrShortDst { + newDest := make([]byte, len(e.buffer)*2) + copy(newDest, e.buffer) + e.buffer = newDest + } else { + e.buffer[destOffset] = byte('?') + destOffset += 1 + srcOffset += 1 + ok = false + } + } +} diff --git a/charset/encoding_table.go b/charset/encoding_table.go index 37a5550b7..d9e48b9ed 100644 --- a/charset/encoding_table.go +++ b/charset/encoding_table.go @@ -31,7 +31,11 @@ import ( // leading and trailing whitespace. func Lookup(label string) (e encoding.Encoding, name string) { label = strings.ToLower(strings.Trim(label, "\t\n\r\f ")) - enc := encodings[label] + return lookup(Formatted(label)) +} + +func lookup(label EncodingLabel) (e encoding.Encoding, name string) { + enc := encodings[string(label)] return enc.e, enc.name } @@ -258,3 +262,32 @@ var encodings = map[string]struct { "utf-16le": {unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), "utf-16le"}, "x-user-defined": {charmap.XUserDefined, "x-user-defined"}, } + +// FindNextCharacterLength is used in lexer.peek() to determine the next character length. 
+func FindNextCharacterLength(label string) func([]byte) int { + if f, ok := encodingNextCharacterLength[label]; ok { + return f + } + return nil +} + +var encodingNextCharacterLength = map[string]func([]byte) int{ + // https://en.wikipedia.org/wiki/GBK_(character_encoding)#Layout_diagram + "gbk": func(bs []byte) int { + if len(bs) == 0 || bs[0] < 0x80 { + // A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII. + return 1 + } + return 2 + }, + "utf-8": func(bs []byte) int { + if len(bs) == 0 || bs[0] < 0x80 { + return 1 + } else if bs[0] < 0xe0 { + return 2 + } else if bs[0] < 0xf0 { + return 3 + } + return 4 + }, +} diff --git a/hintparserimpl.go b/hintparserimpl.go index 129a0050b..98c28c071 100644 --- a/hintparserimpl.go +++ b/hintparserimpl.go @@ -129,11 +129,11 @@ func (hp *hintParser) parse(input string, sqlMode mysql.SQLMode, initPos Pos) ([ hp.result = nil hp.lexer.reset(input[3:]) hp.lexer.SetSQLMode(sqlMode) - hp.lexer.r.p = Pos{ + hp.lexer.r.updatePos(Pos{ Line: initPos.Line, Col: initPos.Col + 3, // skipped the initial '/*+' Offset: 0, - } + }) hp.lexer.inBangComment = true // skip the final '*/' (we need the '*/' for reporting warnings) yyhintParse(&hp.lexer, hp) diff --git a/lexer.go b/lexer.go index ecdda4be3..284145844 100644 --- a/lexer.go +++ b/lexer.go @@ -21,6 +21,8 @@ import ( "unicode" "unicode/utf8" + "github.com/pingcap/errors" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" tidbfeature "github.com/pingcap/parser/tidb" ) @@ -39,6 +41,8 @@ type Scanner struct { r reader buf bytes.Buffer + encoding charset.Encoding + errs []error warns []error stmtStartPos int @@ -134,11 +138,28 @@ func (s *Scanner) AppendError(err error) { s.errs = append(s.errs, err) } +func (s *Scanner) tryDecodeToUTF8String(sql string) string { + if !s.encoding.Enabled() { + name := s.encoding.Name() + if len(name) > 0 { + s.AppendError(errors.Errorf("Encoding %s is not supported", name)) + s.lastErrorAsWarn() + 
} + return sql + } + utf8Lit, ok := s.encoding.Decode(Slice(sql)) + if !ok { + s.AppendError(errors.Errorf("Cannot convert string '%x' from %s to utf8mb4", sql, s.encoding.Name())) + s.lastErrorAsWarn() + } + return utf8Lit +} + func (s *Scanner) getNextToken() int { r := s.r tok, pos, lit := s.scan() if tok == identifier { - tok = handleIdent(&yySymType{}) + tok = s.handleIdent(&yySymType{}) } if tok == identifier { if tok1 := s.isTokenIdentifier(lit, pos.Offset); tok1 != 0 { @@ -163,7 +184,7 @@ func (s *Scanner) Lex(v *yySymType) int { v.offset = pos.Offset v.ident = lit if tok == identifier { - tok = handleIdent(v) + tok = s.handleIdent(v) } if tok == identifier { if tok1 := s.isTokenIdentifier(lit, pos.Offset); tok1 != 0 { @@ -240,6 +261,7 @@ func (s *Scanner) EnableWindowFunc(val bool) { func (s *Scanner) InheritScanner(sql string) *Scanner { return &Scanner{ r: reader{s: sql}, + encoding: s.encoding, sqlMode: s.sqlMode, supportWindowFunc: s.supportWindowFunc, } @@ -250,6 +272,22 @@ func NewScanner(s string) *Scanner { return &Scanner{r: reader{s: s}} } +func (s *Scanner) handleIdent(lval *yySymType) int { + str := lval.ident + // A character string literal may have an optional character set introducer and COLLATE clause: + // [_charset_name]'string' [COLLATE collation_name] + // See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + if !strings.HasPrefix(str, "_") { + return identifier + } + cs, err := charset.GetCharsetInfo(str[1:]) + if err != nil { + return identifier + } + lval.ident = cs.Name + return underscoreCS +} + func (s *Scanner) skipWhitespace() rune { return s.r.incAsLongAs(unicode.IsSpace) } @@ -266,7 +304,7 @@ func (s *Scanner) scan() (tok int, pos Pos, lit string) { return 0, pos, "" } - if !s.r.eof() && isIdentExtend(ch0) { + if isIdentExtend(ch0) { return scanIdentifier(s) } @@ -302,7 +340,7 @@ func startWithXx(s *Scanner) (tok int, pos Pos, lit string) { } return } - s.r.p = pos + s.r.updatePos(pos) return scanIdentifier(s) } 
@@ -334,7 +372,7 @@ func startWithBb(s *Scanner) (tok int, pos Pos, lit string) { } return } - s.r.p = pos + s.r.updatePos(pos) return scanIdentifier(s) } @@ -762,7 +800,7 @@ func (s *Scanner) scanBit() { } func (s *Scanner) scanFloat(beg *Pos) (tok int, pos Pos, lit string) { - s.r.p = *beg + s.r.updatePos(*beg) // float = D1 . D2 e D3 s.scanDigits() ch0 := s.r.peek() @@ -784,7 +822,7 @@ func (s *Scanner) scanFloat(beg *Pos) (tok int, pos Pos, lit string) { // D1 . D2 e XX when XX is not D3, parse the result to an identifier. // 9e9e = 9e9(float) + e(identifier) // 9est = 9est(identifier) - s.r.p = *beg + s.r.updatePos(*beg) s.r.incAsLongAs(isIdentChar) tok = identifier } @@ -810,7 +848,7 @@ func (s *Scanner) scanVersionDigits(min, max int) { if isDigit(ch) { s.r.inc() } else if i < min { - s.r.p = pos + s.r.updatePos(pos) return } else { break @@ -832,7 +870,7 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) { state = expectChar break } - s.r.p = pos + s.r.updatePos(pos) return nil case expectChar: if isIdentChar(ch) { @@ -840,7 +878,7 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) { state = obtainChar break } - s.r.p = pos + s.r.updatePos(pos) return nil case obtainChar: if isIdentChar(ch) { @@ -856,11 +894,11 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) { featureIDs = append(featureIDs, b.String()) return featureIDs } - s.r.p = pos + s.r.updatePos(pos) return nil } } - s.r.p = pos + s.r.updatePos(pos) return nil } @@ -876,6 +914,9 @@ type reader struct { s string p Pos w int + + peekRune rune + peekRuneUpdated bool } var eof = Pos{-1, -1, -1} @@ -888,21 +929,22 @@ func (r *reader) eof() bool { // if reader meets EOF, it will return unicode.ReplacementChar. to distinguish from // the real unicode.ReplacementChar, the caller should call r.eof() again to check. 
func (r *reader) peek() rune { + if r.peekRuneUpdated { + return r.peekRune + } if r.eof() { return unicode.ReplacementChar } v, w := rune(r.s[r.p.Offset]), 1 - switch { - case v == 0: - r.w = w - return v // illegal UTF-8 encoding - case v >= 0x80: + if v >= 0x80 { v, w = utf8.DecodeRuneInString(r.s[r.p.Offset:]) if v == utf8.RuneError && w == 1 { - v = rune(r.s[r.p.Offset]) // illegal UTF-8 encoding + v = rune(r.s[r.p.Offset]) // illegal encoding } } r.w = w + r.peekRune = v + r.peekRuneUpdated = true return v } @@ -915,6 +957,7 @@ func (r *reader) inc() { } r.p.Offset += r.w r.p.Col++ + r.peekRuneUpdated = false } func (r *reader) incN(n int) { @@ -936,6 +979,13 @@ func (r *reader) pos() Pos { return r.p } +func (r *reader) updatePos(pos Pos) { + if r.p.Offset != pos.Offset { + r.peekRuneUpdated = false + } + r.p = pos +} + func (r *reader) data(from *Pos) string { return r.s[from.Offset:r.p.Offset] } diff --git a/misc.go b/misc.go index abf66cf26..078772123 100644 --- a/misc.go +++ b/misc.go @@ -14,9 +14,8 @@ package parser import ( - "strings" - - "github.com/pingcap/parser/charset" + "reflect" + "unsafe" ) func isLetter(ch rune) bool { @@ -991,18 +990,13 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int { return tok } -func handleIdent(lval *yySymType) int { - s := lval.ident - // A character string literal may have an optional character set introducer and COLLATE clause: - // [_charset_name]'string' [COLLATE collation_name] - // See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html - if !strings.HasPrefix(s, "_") { - return identifier - } - cs, err := charset.GetCharsetInfo(s[1:]) - if err != nil { - return identifier - } - lval.ident = cs.Name - return underscoreCS +// Slice converts string to slice without copy. +// Use at your own risk. 
+func Slice(s string) (b []byte) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pBytes.Data = pString.Data + pBytes.Len = pString.Len + pBytes.Cap = pString.Len + return } diff --git a/parser.go b/parser.go index ca4ebb674..61a2151a6 100644 --- a/parser.go +++ b/parser.go @@ -14908,7 +14908,7 @@ yynewstate: case 1166: { // See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html - co, err := charset.GetDefaultCollation(yyS[yypt-1].ident) + co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident)) return 1 @@ -14932,7 +14932,7 @@ yynewstate: } case 1169: { - co, err := charset.GetDefaultCollation(yyS[yypt-1].ident) + co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident)) return 1 @@ -14948,7 +14948,7 @@ yynewstate: } case 1170: { - co, err := charset.GetDefaultCollation(yyS[yypt-1].ident) + co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident)) return 1 diff --git a/parser.y b/parser.y index 8287c9036..77e550a93 100644 --- a/parser.y +++ b/parser.y @@ -6452,7 +6452,7 @@ Literal: | "UNDERSCORE_CHARSET" stringLit { // See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html - co, err := charset.GetDefaultCollation($1) + co, err := charset.GetDefaultCollationLegacy($1) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1)) return 1 @@ -6476,7 +6476,7 @@ Literal: } | "UNDERSCORE_CHARSET" hexLit { - co, err := charset.GetDefaultCollation($1) + co, err := charset.GetDefaultCollationLegacy($1) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1)) return 1 @@ -6492,7 +6492,7 @@ 
Literal: } | "UNDERSCORE_CHARSET" bitLit { - co, err := charset.GetDefaultCollation($1) + co, err := charset.GetDefaultCollationLegacy($1) if err != nil { yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1)) return 1 diff --git a/parser_test.go b/parser_test.go index adb374c81..b30515edb 100644 --- a/parser_test.go +++ b/parser_test.go @@ -6320,3 +6320,66 @@ func (s *testParserSuite) TestPlanRecreator(c *C) { c.Assert(v.Stmt.Text(), Equals, "SELECT a FROM t") c.Assert(v.Analyze, IsTrue) } + +func (s *testParserSuite) TestGBKEncoding(c *C) { + p := parser.New() + gbkEncoding, _ := charset.Lookup("gbk") + encoder := gbkEncoding.NewEncoder() + sql, err := encoder.String("create table 测试表 (测试列 varchar(255) default 'GBK测试用例');") + c.Assert(err, IsNil) + + stmt, err := p.ParseOneStmt(sql, "", "") + c.Assert(err, IsNil) + checker := &gbkEncodingChecker{} + _, _ = stmt.Accept(checker) + c.Assert(checker.tblName, Not(Equals), "测试表") + c.Assert(checker.colName, Not(Equals), "测试列") + + p.SetParserConfig(parser.ParserConfig{CharsetClient: "gbk"}) + stmt, err = p.ParseOneStmt(sql, "", "") + c.Assert(err, IsNil) + _, _ = stmt.Accept(checker) + c.Assert(checker.tblName, Equals, "测试表") + c.Assert(checker.colName, Equals, "测试列") + c.Assert(checker.expr, Equals, "GBK测试用例") + + utf8SQL := "select '芢' from `玚`;" + sql, err = encoder.String(utf8SQL) + c.Assert(err, IsNil) + stmt, err = p.ParseOneStmt(sql, "", "") + c.Assert(err, IsNil) + stmt, err = p.ParseOneStmt("select '\xc6\x5c' from `\xab\x60`;", "", "") + c.Assert(err, IsNil) + + p.SetParserConfig(parser.ParserConfig{CharsetClient: ""}) + stmt, err = p.ParseOneStmt("select _gbk '\xc6\x5c' from dual;", "", "") + c.Assert(err, NotNil) +} + +type gbkEncodingChecker struct { + tblName string + colName string + expr string +} + +func (g *gbkEncodingChecker) Enter(n ast.Node) (node ast.Node, skipChildren bool) { + if tn, ok := n.(*ast.TableName); ok { + g.tblName = tn.Name.O + return n, false + } + if cn, ok := 
n.(*ast.ColumnName); ok { + g.colName = cn.Name.O + return n, false + } + if c, ok := n.(*ast.ColumnOption); ok { + if ve, ok := c.Expr.(ast.ValueExpr); ok { + g.expr = ve.GetString() + return n, false + } + } + return n, false +} + +func (g *gbkEncodingChecker) Leave(n ast.Node) (node ast.Node, ok bool) { + return n, true +} diff --git a/yy_parser.go b/yy_parser.go index ff53d6c39..8a364c8d8 100644 --- a/yy_parser.go +++ b/yy_parser.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" ) @@ -72,6 +73,7 @@ type ParserConfig struct { EnableWindowFunction bool EnableStrictDoubleTypeCheck bool SkipPositionRecording bool + CharsetClient string // CharsetClient indicates how to decode the original SQL. } // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. @@ -132,11 +134,13 @@ func (parser *Parser) SetParserConfig(config ParserConfig) { parser.EnableWindowFunc(config.EnableWindowFunction) parser.SetStrictDoubleTypeCheck(config.EnableStrictDoubleTypeCheck) parser.lexer.skipPositionRecording = config.SkipPositionRecording + parser.lexer.encoding = *charset.NewEncoding(charset.Format(config.CharsetClient)) } // Parse parses a query string to raw ast.StmtNode. // If charset or collation is "", default charset and collation will be used. func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) { + sql = parser.lexer.tryDecodeToUTF8String(sql) if charset == "" { charset = mysql.DefaultCharset }