diff --git a/Makefile b/Makefile index e61cf711397..acf5363a00b 100644 --- a/Makefile +++ b/Makefile @@ -17,6 +17,7 @@ MAKEFLAGS = -s export GOBIN=$(PWD)/bin export GO111MODULE=on export GODEBUG=tls13=0 +export REWRITER=go/vt/sqlparser/rewriter.go # Disabled parallel processing of target prerequisites to avoid that integration tests are racing each other (e.g. for ports) and may fail. # Since we are not using this Makefile for compilation, limiting parallelism will not increase build time. @@ -87,6 +88,11 @@ install: build parser: make -C go/vt/sqlparser +visitor: + go build -o visitorgen go/visitorgen/main/main.go + ./visitorgen -input=go/vt/sqlparser/ast.go -output=$(REWRITER) + rm ./visitorgen + # To pass extra flags, run test.go manually. # For example: go run test.go -docker=false -- --extra-flag # For more info see: go run test.go -help @@ -100,9 +106,10 @@ clean: go clean -i ./go/... rm -rf third_party/acolyte rm -rf go/vt/.proto.tmp + rm -rf ./visitorgen # Remove everything including stuff pulled down by bootstrap.sh -cleanall: +cleanall: clean # directories created by bootstrap.sh # - exclude vtdataroot and vthook as they may have data we want rm -rf bin dist lib pkg diff --git a/go/visitorgen/ast_walker.go b/go/visitorgen/ast_walker.go new file mode 100644 index 00000000000..822fb6c4c5e --- /dev/null +++ b/go/visitorgen/ast_walker.go @@ -0,0 +1,130 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package visitorgen + +import ( + "go/ast" + "reflect" +) + +var _ ast.Visitor = (*walker)(nil) + +type walker struct { + result SourceFile +} + +// Walk walks the given AST and translates it to the simplified AST used by the next steps +func Walk(node ast.Node) *SourceFile { + var w walker + ast.Walk(&w, node) + return &w.result +} + +// Visit implements the ast.Visitor interface +func (w *walker) Visit(node ast.Node) ast.Visitor { + switch n := node.(type) { + case *ast.TypeSpec: + switch t2 := n.Type.(type) { + case *ast.InterfaceType: + w.append(&InterfaceDeclaration{ + name: n.Name.Name, + block: "", + }) + case *ast.StructType: + var fields []*Field + for _, f := range t2.Fields.List { + for _, name := range f.Names { + fields = append(fields, &Field{ + name: name.Name, + typ: sastType(f.Type), + }) + } + + } + w.append(&StructDeclaration{ + name: n.Name.Name, + fields: fields, + }) + case *ast.ArrayType: + w.append(&TypeAlias{ + name: n.Name.Name, + typ: &Array{inner: sastType(t2.Elt)}, + }) + case *ast.Ident: + w.append(&TypeAlias{ + name: n.Name.Name, + typ: &TypeString{t2.Name}, + }) + + default: + panic(reflect.TypeOf(t2)) + } + case *ast.FuncDecl: + if len(n.Recv.List) > 1 || len(n.Recv.List[0].Names) > 1 { + panic("don't know what to do!") + } + var f *Field + if len(n.Recv.List) == 1 { + r := n.Recv.List[0] + t := sastType(r.Type) + if len(r.Names) > 1 { + panic("don't know what to do!") + } + if len(r.Names) == 1 { + f = &Field{ + name: r.Names[0].Name, + typ: t, + } + } else { + f = &Field{ + name: "", + typ: t, + } + } + } + + w.append(&FuncDeclaration{ + receiver: f, + name: n.Name.Name, + block: "", + arguments: nil, + }) + } + + return w +} + +func (w *walker) append(line Sast) { + w.result.lines = append(w.result.lines, line) +} + +func sastType(e ast.Expr) Type { + switch n := e.(type) { + case *ast.StarExpr: + return &Ref{sastType(n.X)} + case *ast.Ident: + return &TypeString{n.Name} + case *ast.ArrayType: + return &Array{inner: 
sastType(n.Elt)} + case *ast.InterfaceType: + return &TypeString{"interface{}"} + case *ast.StructType: + return &TypeString{"struct{}"} + } + + panic(reflect.TypeOf(e)) +} diff --git a/go/visitorgen/ast_walker_test.go b/go/visitorgen/ast_walker_test.go new file mode 100644 index 00000000000..a4b01f70835 --- /dev/null +++ b/go/visitorgen/ast_walker_test.go @@ -0,0 +1,239 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package visitorgen + +import ( + "go/parser" + "go/token" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func TestSingleInterface(t *testing.T) { + input := ` +package sqlparser + +type Nodeiface interface { + iNode() +} +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&InterfaceDeclaration{ + name: "Nodeiface", + block: "", + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestEmptyStruct(t *testing.T) { + input := ` +package sqlparser + +type Empty struct {} +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&StructDeclaration{ + name: "Empty", + fields: []*Field{}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithStringField(t *testing.T) { + 
input := ` +package sqlparser + +type Struct struct { + field string +} +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&StructDeclaration{ + name: "Struct", + fields: []*Field{{ + name: "field", + typ: &TypeString{typName: "string"}, + }}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithDifferentTypes(t *testing.T) { + input := ` +package sqlparser + +type Struct struct { + field string + reference *string + array []string + arrayOfRef []*string +} +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&StructDeclaration{ + name: "Struct", + fields: []*Field{{ + name: "field", + typ: &TypeString{typName: "string"}, + }, { + name: "reference", + typ: &Ref{&TypeString{typName: "string"}}, + }, { + name: "array", + typ: &Array{&TypeString{typName: "string"}}, + }, { + name: "arrayOfRef", + typ: &Array{&Ref{&TypeString{typName: "string"}}}, + }}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithTwoStringFieldInOneLine(t *testing.T) { + input := ` +package sqlparser + +type Struct struct { + left, right string +} +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&StructDeclaration{ + name: "Struct", + fields: []*Field{{ + name: "left", + typ: &TypeString{typName: "string"}, + }, { + name: "right", + typ: &TypeString{typName: "string"}, + }}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithSingleMethod(t *testing.T) { + input := ` +package sqlparser + +type Empty struct {} + +func (*Empty) method() {} +` + + fset := token.NewFileSet() + ast, err := 
parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{ + &StructDeclaration{ + name: "Empty", + fields: []*Field{}}, + &FuncDeclaration{ + receiver: &Field{ + name: "", + typ: &Ref{&TypeString{"Empty"}}, + }, + name: "method", + block: "", + arguments: []*Field{}, + }, + }, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestSingleArrayType(t *testing.T) { + input := ` +package sqlparser + +type Strings []string +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&TypeAlias{ + name: "Strings", + typ: &Array{&TypeString{"string"}}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestSingleTypeAlias(t *testing.T) { + input := ` +package sqlparser + +type String string +` + + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "ast.go", input, 0) + require.NoError(t, err) + + result := Walk(ast) + expected := SourceFile{ + lines: []Sast{&TypeAlias{ + name: "String", + typ: &TypeString{"string"}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} diff --git a/go/visitorgen/main/main.go b/go/visitorgen/main/main.go new file mode 100644 index 00000000000..ae09ad7efdf --- /dev/null +++ b/go/visitorgen/main/main.go @@ -0,0 +1,166 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "flag" + "fmt" + "go/parser" + "go/token" + "os" + + "vitess.io/vitess/go/visitorgen" +) + +var ( + inputFile = flag.String("input", "", "input file to use") + outputFile = flag.String("output", "", "output file") +) + +const usage = `Usage of visitorgen: + +go run go/visitorgen/main/main.go -input=/path/to/ast.go -output=/path/to/rewriter.go +` + +func main() { + flag.Usage = func() { + os.Stderr.WriteString(usage) + os.Stderr.WriteString("\nOptions:\n") + flag.PrintDefaults() + + } + flag.Parse() + + if *inputFile == "" || *outputFile == "" { + fmt.Println("> " + *inputFile) + fmt.Println("> " + *outputFile) + panic("need input and output file") + } + + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, *inputFile, nil, parser.DeclarationErrors) + if err != nil { + panic(err) + } + + astWalkResult := visitorgen.Walk(file) + vp := visitorgen.Transform(astWalkResult) + vd := visitorgen.ToVisitorPlan(vp) + + replacementMethods := visitorgen.EmitReplacementMethods(vd) + typeSwitch := visitorgen.EmitTypeSwitches(vd) + + fw := newFileWriter(*outputFile) + defer fw.Close() + + fw.writeln(fileHeader) + fw.writeln(replacementMethods) + fw.write(applyHeader) + fw.writeln(typeSwitch) + fw.writeln(fileFooter) +} + +type fileWriter struct { + file *os.File +} + +func newFileWriter(file string) *fileWriter { + f, err := os.Create(file) + if err != nil { + panic(err) + } + return &fileWriter{file: f} +} + +func (fw *fileWriter) writeln(s string) { + fw.write(s) + fw.write("\n") +} + +func (fw *fileWriter) write(s string) { + _, err := fw.file.WriteString(s) + if err != nil { + panic(err) + } +} + +func (fw *fileWriter) Close() { + fw.file.Close() +} + +const fileHeader = `// Code generated by visitorgen/main/main.go. DO NOT EDIT. 
+ +package sqlparser + +//go:generate make visitor + +import ( + "reflect" +) + +type replacerFunc func(newNode, parent SQLNode) + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor +} +` + +const applyHeader = ` +// apply is where the visiting happens. Here is where we keep the big switch-case that will be used +// to do the actual visiting of SQLNodes +func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { + if node == nil || isNilValue(node) { + return + } + + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead + saved := a.cursor + a.cursor.replacer = replacer + a.cursor.node = node + a.cursor.parent = parent + + if a.pre != nil && !a.pre(&a.cursor) { + a.cursor = saved + return + } + + // walk children + // (the order of the cases is alphabetical) + switch n := node.(type) { + case nil: + ` + +const fileFooter = ` + default: + panic("unknown ast type " + reflect.TypeOf(node).String()) + } + + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + + a.cursor = saved +} + +func isNilValue(i interface{}) bool { + valueOf := reflect.ValueOf(i) + kind := valueOf.Kind() + isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice + return isNullable && valueOf.IsNil() +}` diff --git a/go/visitorgen/sast.go b/go/visitorgen/sast.go new file mode 100644 index 00000000000..e46485e8f5d --- /dev/null +++ b/go/visitorgen/sast.go @@ -0,0 +1,178 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package visitorgen + +// simplified ast - when reading the golang ast of the ast.go file, we translate the golang ast objects +// to this much simpler format, that contains only the necessary information and no more +type ( + // SourceFile contains all important lines from an ast.go file + SourceFile struct { + lines []Sast + } + + // Sast or simplified AST, is a representation of the ast.go lines we are interested in + Sast interface { + toSastString() string + } + + // InterfaceDeclaration represents a declaration of an interface. This is used to keep track of which types + // need to be handled by the visitor framework + InterfaceDeclaration struct { + name, block string + } + + // TypeAlias is used whenever we see a `type XXX YYY` - XXX is the new name for YYY. + // Note that YYY could be an array or a reference + TypeAlias struct { + name string + typ Type + } + + // FuncDeclaration represents a function declaration. These are tracked to know which types implement interfaces. + FuncDeclaration struct { + receiver *Field + name, block string + arguments []*Field + } + + // StructDeclaration represents a struct. It contains the fields and their types + StructDeclaration struct { + name string + fields []*Field + } + + // Field is a field in a struct - a name with a type tuple + Field struct { + name string + typ Type + } + + // Type represents a type in the golang type system. Used to keep track of type we need to handle, + // and the types of fields. 
+ Type interface { + toTypString() string + rawTypeName() string + } + + // TypeString is a raw type name, such as `string` + TypeString struct { + typName string + } + + // Ref is a reference to something, such as `*string` + Ref struct { + inner Type + } + + // Array is an array of things, such as `[]string` + Array struct { + inner Type + } +) + +var _ Sast = (*InterfaceDeclaration)(nil) +var _ Sast = (*StructDeclaration)(nil) +var _ Sast = (*FuncDeclaration)(nil) +var _ Sast = (*TypeAlias)(nil) + +var _ Type = (*TypeString)(nil) +var _ Type = (*Ref)(nil) +var _ Type = (*Array)(nil) + +// String returns a textual representation of the SourceFile. This is for testing purposed +func (t *SourceFile) String() string { + var result string + for _, l := range t.lines { + result += l.toSastString() + result += "\n" + } + + return result +} + +func (t *Ref) toTypString() string { + return "*" + t.inner.toTypString() +} + +func (t *Array) toTypString() string { + return "[]" + t.inner.toTypString() +} + +func (t *TypeString) toTypString() string { + return t.typName +} + +func (f *FuncDeclaration) toSastString() string { + var receiver string + if f.receiver != nil { + receiver = "(" + f.receiver.String() + ") " + } + var args string + for i, arg := range f.arguments { + if i > 0 { + args += ", " + } + args += arg.String() + } + + return "func " + receiver + f.name + "(" + args + ") {" + blockInNewLines(f.block) + "}" +} + +func (i *InterfaceDeclaration) toSastString() string { + return "type " + i.name + " interface {" + blockInNewLines(i.block) + "}" +} + +func (a *TypeAlias) toSastString() string { + return "type " + a.name + " " + a.typ.toTypString() +} + +func (s *StructDeclaration) toSastString() string { + var block string + for _, f := range s.fields { + block += "\t" + f.String() + "\n" + } + + return "type " + s.name + " struct {" + blockInNewLines(block) + "}" +} + +func blockInNewLines(block string) string { + if block == "" { + return "" + } + return "\n" + 
block + "\n" +} + +// String returns a string representation of a field +func (f *Field) String() string { + if f.name != "" { + return f.name + " " + f.typ.toTypString() + } + + return f.typ.toTypString() +} + +func (t *TypeString) rawTypeName() string { + return t.typName +} + +func (t *Ref) rawTypeName() string { + return t.inner.rawTypeName() +} + +func (t *Array) rawTypeName() string { + return t.inner.rawTypeName() +} diff --git a/go/visitorgen/struct_producer.go b/go/visitorgen/struct_producer.go new file mode 100644 index 00000000000..1c293f30803 --- /dev/null +++ b/go/visitorgen/struct_producer.go @@ -0,0 +1,253 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package visitorgen + +import ( + "fmt" + "sort" +) + +// VisitorData is the data needed to produce the output file +type ( + // VisitorItem represents something that needs to be added to the rewriter infrastructure + VisitorItem interface { + toFieldItemString() string + typeName() string + asSwitchCase() string + asReplMethod() string + getFieldName() string + } + + // SingleFieldItem is a single field in a struct + SingleFieldItem struct { + StructType, FieldType Type + FieldName string + } + + // ArrayFieldItem is an array field in a struct + ArrayFieldItem struct { + StructType, ItemType Type + FieldName string + } + + // ArrayItem is an array that implements SQLNode + ArrayItem struct { + StructType, ItemType Type + } + + // VisitorPlan represents all the output needed for the rewriter + VisitorPlan struct { + Switches []*SwitchCase // The cases for the big switch statement used to implement the visitor + } + + // SwitchCase is what we need to know to produce all the type switch cases in the visitor. + SwitchCase struct { + Type Type + Fields []VisitorItem + } +) + +var _ VisitorItem = (*SingleFieldItem)(nil) +var _ VisitorItem = (*ArrayItem)(nil) +var _ VisitorItem = (*ArrayFieldItem)(nil) +var _ sort.Interface = (*VisitorPlan)(nil) +var _ sort.Interface = (*SwitchCase)(nil) + +// ToVisitorPlan transforms the source information into a plan for the visitor code that needs to be produced +func ToVisitorPlan(input *SourceInformation) *VisitorPlan { + var output VisitorPlan + + for _, typ := range input.interestingTypes { + switchit := &SwitchCase{Type: typ} + stroct, isStruct := input.structs[typ.rawTypeName()] + if isStruct { + for _, f := range stroct.fields { + switchit.Fields = append(switchit.Fields, trySingleItem(input, f, typ)...) 
+ } + } else { + itemType := input.getItemTypeOfArray(typ) + if itemType != nil && input.isSQLNode(itemType) { + switchit.Fields = append(switchit.Fields, &ArrayItem{ + StructType: typ, + ItemType: itemType, + }) + } + } + sort.Sort(switchit) + output.Switches = append(output.Switches, switchit) + } + sort.Sort(&output) + return &output +} + +func trySingleItem(input *SourceInformation, f *Field, typ Type) []VisitorItem { + if input.isSQLNode(f.typ) { + return []VisitorItem{&SingleFieldItem{ + StructType: typ, + FieldType: f.typ, + FieldName: f.name, + }} + } + + arrType, isArray := f.typ.(*Array) + if isArray && input.isSQLNode(arrType.inner) { + return []VisitorItem{&ArrayFieldItem{ + StructType: typ, + ItemType: arrType.inner, + FieldName: f.name, + }} + } + return []VisitorItem{} +} + +// String returns a string, used for testing +func (v *VisitorPlan) String() string { + var sb builder + for _, s := range v.Switches { + sb.appendF("Type: %v", s.Type.toTypString()) + for _, f := range s.Fields { + sb.appendF("\t%v", f.toFieldItemString()) + } + } + return sb.String() +} + +func (s *SingleFieldItem) toFieldItemString() string { + return fmt.Sprintf("single item: %v of type: %v", s.FieldName, s.FieldType.toTypString()) +} + +func (s *SingleFieldItem) asSwitchCase() string { + return fmt.Sprintf(` a.apply(node, n.%s, %s)`, s.FieldName, s.typeName()) +} + +func (s *SingleFieldItem) asReplMethod() string { + _, isRef := s.StructType.(*Ref) + + if isRef { + return fmt.Sprintf(`func %s(newNode, parent SQLNode) { + parent.(%s).%s = newNode.(%s) +}`, s.typeName(), s.StructType.toTypString(), s.FieldName, s.FieldType.toTypString()) + } + + return fmt.Sprintf(`func %s(newNode, parent SQLNode) { + tmp := parent.(%s) + tmp.%s = newNode.(%s) +}`, s.typeName(), s.StructType.toTypString(), s.FieldName, s.FieldType.toTypString()) + +} + +func (ai *ArrayItem) asReplMethod() string { + name := ai.typeName() + return fmt.Sprintf(`type %s int + +func (r *%s) replace(newNode, 
container SQLNode) { + container.(%s)[int(*r)] = newNode.(%s) +} + +func (r *%s) inc() { + *r++ +}`, name, name, ai.StructType.toTypString(), ai.ItemType.toTypString(), name) +} + +func (afi *ArrayFieldItem) asReplMethod() string { + name := afi.typeName() + return fmt.Sprintf(`type %s int + +func (r *%s) replace(newNode, container SQLNode) { + container.(%s).%s[int(*r)] = newNode.(%s) +} + +func (r *%s) inc() { + *r++ +}`, name, name, afi.StructType.toTypString(), afi.FieldName, afi.ItemType.toTypString(), name) +} + +func (s *SingleFieldItem) getFieldName() string { + return s.FieldName +} + +func (s *SingleFieldItem) typeName() string { + return "replace" + s.StructType.rawTypeName() + s.FieldName +} + +func (afi *ArrayFieldItem) toFieldItemString() string { + return fmt.Sprintf("array field item: %v.%v contains items of type %v", afi.StructType.toTypString(), afi.FieldName, afi.ItemType.toTypString()) +} + +func (ai *ArrayItem) toFieldItemString() string { + return fmt.Sprintf("array item: %v containing items of type %v", ai.StructType.toTypString(), ai.ItemType.toTypString()) +} + +func (ai *ArrayItem) getFieldName() string { + panic("Should not be called!") +} + +func (afi *ArrayFieldItem) getFieldName() string { + return afi.FieldName +} + +func (ai *ArrayItem) asSwitchCase() string { + return fmt.Sprintf(` replacer := %s(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + }`, ai.typeName()) +} + +func (afi *ArrayFieldItem) asSwitchCase() string { + return fmt.Sprintf(` replacer%s := %s(0) + replacer%sB := &replacer%s + for _, item := range n.%s { + a.apply(node, item, replacer%sB.replace) + replacer%sB.inc() + }`, afi.FieldName, afi.typeName(), afi.FieldName, afi.FieldName, afi.FieldName, afi.FieldName, afi.FieldName) +} + +func (ai *ArrayItem) typeName() string { + return "replace" + ai.StructType.rawTypeName() + "Items" +} + +func (afi *ArrayFieldItem) typeName() string { + return 
"replace" + afi.StructType.rawTypeName() + afi.FieldName +} +func (v *VisitorPlan) Len() int { + return len(v.Switches) +} + +func (v *VisitorPlan) Less(i, j int) bool { + return v.Switches[i].Type.rawTypeName() < v.Switches[j].Type.rawTypeName() +} + +func (v *VisitorPlan) Swap(i, j int) { + temp := v.Switches[i] + v.Switches[i] = v.Switches[j] + v.Switches[j] = temp +} +func (s *SwitchCase) Len() int { + return len(s.Fields) +} + +func (s *SwitchCase) Less(i, j int) bool { + return s.Fields[i].getFieldName() < s.Fields[j].getFieldName() +} + +func (s *SwitchCase) Swap(i, j int) { + temp := s.Fields[i] + s.Fields[i] = s.Fields[j] + s.Fields[j] = temp +} diff --git a/go/visitorgen/struct_producer_test.go b/go/visitorgen/struct_producer_test.go new file mode 100644 index 00000000000..065b532a9eb --- /dev/null +++ b/go/visitorgen/struct_producer_test.go @@ -0,0 +1,423 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package visitorgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmptyStructVisitor(t *testing.T) { + /* + type Node interface{} + type Struct struct {} + func (*Struct) iNode() {} + */ + + input := &SourceInformation{ + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}, + }, + interfaces: map[string]bool{ + "Node": true, + }, + structs: map[string]*StructDeclaration{ + "Struct": {name: "Struct", fields: []*Field{}}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithSqlNodeField(t *testing.T) { + /* + type Node interface{} + type Struct struct { + Field Node + } + func (*Struct) iNode() {} + */ + input := &SourceInformation{ + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}, + }, + interfaces: map[string]bool{ + "Node": true, + }, + structs: map[string]*StructDeclaration{ + "Struct": {name: "Struct", fields: []*Field{ + {name: "Field", typ: &TypeString{"Node"}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{&SingleFieldItem{ + StructType: &Ref{&TypeString{"Struct"}}, + FieldType: &TypeString{"Node"}, + FieldName: "Field", + }}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithStringField2(t *testing.T) { + /* + type Node interface{} + type Struct struct { + Field Node + } + func (*Struct) iNode() {} + */ + + input := &SourceInformation{ + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}, + }, + interfaces: map[string]bool{ + "Node": true, + }, + structs: map[string]*StructDeclaration{ + "Struct": 
{name: "Struct", fields: []*Field{ + {name: "Field", typ: &TypeString{"string"}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestArrayAsSqlNode(t *testing.T) { + /* + type NodeInterface interface { + iNode() + } + + func (*NodeArray) iNode{} + + type NodeArray []NodeInterface + */ + + input := &SourceInformation{ + interfaces: map[string]bool{"NodeInterface": true}, + interestingTypes: map[string]Type{ + "*NodeArray": &Ref{&TypeString{"NodeArray"}}}, + structs: map[string]*StructDeclaration{}, + typeAliases: map[string]*TypeAlias{ + "NodeArray": { + name: "NodeArray", + typ: &Array{&TypeString{"NodeInterface"}}, + }, + }, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"NodeArray"}}, + Fields: []VisitorItem{&ArrayItem{ + StructType: &Ref{&TypeString{"NodeArray"}}, + ItemType: &TypeString{"NodeInterface"}, + }}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithStructField(t *testing.T) { + /* + type Node interface{} + type Struct struct { + Field *Struct + } + func (*Struct) iNode() {} + */ + + input := &SourceInformation{ + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}}, + structs: map[string]*StructDeclaration{ + "Struct": {name: "Struct", fields: []*Field{ + {name: "Field", typ: &Ref{&TypeString{"Struct"}}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{&SingleFieldItem{ + StructType: &Ref{&TypeString{"Struct"}}, + FieldType: &Ref{&TypeString{"Struct"}}, + FieldName: "Field", + }}, + }}, + } + + assert.Equal(t, 
expected.String(), result.String()) +} + +func TestStructWithArrayOfNodes(t *testing.T) { + /* + type NodeInterface interface {} + type Struct struct { + Items []NodeInterface + } + + func (*Struct) iNode{} + */ + + input := &SourceInformation{ + interfaces: map[string]bool{ + "NodeInterface": true, + }, + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}}, + structs: map[string]*StructDeclaration{ + "Struct": {name: "Struct", fields: []*Field{ + {name: "Items", typ: &Array{&TypeString{"NodeInterface"}}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{&ArrayFieldItem{ + StructType: &Ref{&TypeString{"Struct"}}, + ItemType: &TypeString{"NodeInterface"}, + FieldName: "Items", + }}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestStructWithArrayOfStrings(t *testing.T) { + /* + type NodeInterface interface {} + type Struct struct { + Items []string + } + + func (*Struct) iNode{} + */ + + input := &SourceInformation{ + interfaces: map[string]bool{ + "NodeInterface": true, + }, + interestingTypes: map[string]Type{ + "*Struct": &Ref{&TypeString{"Struct"}}}, + structs: map[string]*StructDeclaration{ + "Struct": {name: "Struct", fields: []*Field{ + {name: "Items", typ: &Array{&TypeString{"string"}}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestArrayOfStringsThatImplementSQLNode(t *testing.T) { + /* + type NodeInterface interface {} + type Struct []string + func (Struct) iNode{} + */ + + input := &SourceInformation{ + interfaces: map[string]bool{"NodeInterface": true}, + interestingTypes: map[string]Type{"Struct": 
&Ref{&TypeString{"Struct"}}}, + structs: map[string]*StructDeclaration{}, + typeAliases: map[string]*TypeAlias{ + "Struct": { + name: "Struct", + typ: &Array{&TypeString{"string"}}, + }, + }, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{{ + Type: &Ref{&TypeString{"Struct"}}, + Fields: []VisitorItem{}, + }}, + } + + assert.Equal(t, expected.String(), result.String()) +} + +func TestSortingOfOutputs(t *testing.T) { + /* + type NodeInterface interface {} + type AStruct struct { + AField NodeInterface + BField NodeInterface + } + type BStruct struct { + CField NodeInterface + } + func (*AStruct) iNode{} + func (*BStruct) iNode{} + */ + + input := &SourceInformation{ + interfaces: map[string]bool{"NodeInterface": true}, + interestingTypes: map[string]Type{ + "AStruct": &Ref{&TypeString{"AStruct"}}, + "BStruct": &Ref{&TypeString{"BStruct"}}, + }, + structs: map[string]*StructDeclaration{ + "AStruct": {name: "AStruct", fields: []*Field{ + {name: "BField", typ: &TypeString{"NodeInterface"}}, + {name: "AField", typ: &TypeString{"NodeInterface"}}, + }}, + "BStruct": {name: "BStruct", fields: []*Field{ + {name: "CField", typ: &TypeString{"NodeInterface"}}, + }}, + }, + typeAliases: map[string]*TypeAlias{}, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{ + {Type: &Ref{&TypeString{"AStruct"}}, + Fields: []VisitorItem{ + &SingleFieldItem{ + StructType: &Ref{&TypeString{"AStruct"}}, + FieldType: &TypeString{"NodeInterface"}, + FieldName: "AField", + }, + &SingleFieldItem{ + StructType: &Ref{&TypeString{"AStruct"}}, + FieldType: &TypeString{"NodeInterface"}, + FieldName: "BField", + }}}, + {Type: &Ref{&TypeString{"BStruct"}}, + Fields: []VisitorItem{ + &SingleFieldItem{ + StructType: &Ref{&TypeString{"BStruct"}}, + FieldType: &TypeString{"NodeInterface"}, + FieldName: "CField", + }}}}, + } + assert.Equal(t, expected.String(), result.String()) +} + +func TestAliasOfAlias(t *testing.T) 
{ + /* + type NodeInterface interface { + iNode() + } + + type NodeArray []NodeInterface + type AliasOfAlias NodeArray + + func (NodeArray) iNode{} + func (AliasOfAlias) iNode{} + */ + + input := &SourceInformation{ + interfaces: map[string]bool{"NodeInterface": true}, + interestingTypes: map[string]Type{ + "NodeArray": &TypeString{"NodeArray"}, + "AliasOfAlias": &TypeString{"AliasOfAlias"}, + }, + structs: map[string]*StructDeclaration{}, + typeAliases: map[string]*TypeAlias{ + "NodeArray": { + name: "NodeArray", + typ: &Array{&TypeString{"NodeInterface"}}, + }, + "AliasOfAlias": { + name: "NodeArray", + typ: &TypeString{"NodeArray"}, + }, + }, + } + + result := ToVisitorPlan(input) + + expected := &VisitorPlan{ + Switches: []*SwitchCase{ + {Type: &TypeString{"AliasOfAlias"}, + Fields: []VisitorItem{&ArrayItem{ + StructType: &TypeString{"AliasOfAlias"}, + ItemType: &TypeString{"NodeInterface"}, + }}, + }, + {Type: &TypeString{"NodeArray"}, + Fields: []VisitorItem{&ArrayItem{ + StructType: &TypeString{"NodeArray"}, + ItemType: &TypeString{"NodeInterface"}, + }}, + }}, + } + assert.Equal(t, expected.String(), result.String()) +} diff --git a/go/visitorgen/transformer.go b/go/visitorgen/transformer.go new file mode 100644 index 00000000000..98129be81b1 --- /dev/null +++ b/go/visitorgen/transformer.go @@ -0,0 +1,95 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package visitorgen + +import "fmt" + +// Transform takes an input file and collects the information into an easier to consume format +func Transform(input *SourceFile) *SourceInformation { + interestingTypes := make(map[string]Type) + interfaces := make(map[string]bool) + structs := make(map[string]*StructDeclaration) + typeAliases := make(map[string]*TypeAlias) + + for _, l := range input.lines { + switch line := l.(type) { + case *FuncDeclaration: + interestingTypes[line.receiver.typ.toTypString()] = line.receiver.typ + case *StructDeclaration: + structs[line.name] = line + case *TypeAlias: + typeAliases[line.name] = line + case *InterfaceDeclaration: + interfaces[line.name] = true + } + } + + return &SourceInformation{ + interfaces: interfaces, + interestingTypes: interestingTypes, + structs: structs, + typeAliases: typeAliases, + } +} + +// SourceInformation contains the information from the ast.go file, but in a format that is easier to consume +type SourceInformation struct { + interestingTypes map[string]Type + interfaces map[string]bool + structs map[string]*StructDeclaration + typeAliases map[string]*TypeAlias +} + +func (v *SourceInformation) String() string { + var types string + for _, k := range v.interestingTypes { + types += k.toTypString() + "\n" + } + var structs string + for _, k := range v.structs { + structs += k.toSastString() + "\n" + } + var typeAliases string + for _, k := range v.typeAliases { + typeAliases += k.toSastString() + "\n" + } + + return fmt.Sprintf("Types to build visitor for:\n%s\nStructs with fields: \n%s\nTypeAliases with type: \n%s\n", types, structs, typeAliases) +} + +// getItemTypeOfArray will return nil if the given type is not pointing to a array type. 
// If it is an array type, the type of its items will be returned
+*/ + +package visitorgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSimplestAst(t *testing.T) { + /* + type NodeInterface interface { + iNode() + } + + type NodeStruct struct {} + + func (*NodeStruct) iNode{} + */ + input := &SourceFile{ + lines: []Sast{ + &InterfaceDeclaration{ + name: "NodeInterface", + block: "// an interface lives here"}, + &StructDeclaration{ + name: "NodeStruct", + fields: []*Field{}}, + &FuncDeclaration{ + receiver: &Field{ + name: "", + typ: &Ref{&TypeString{"NodeStruct"}}, + }, + name: "iNode", + block: "", + arguments: []*Field{}}, + }, + } + + expected := &SourceInformation{ + interestingTypes: map[string]Type{ + "*NodeStruct": &Ref{&TypeString{"NodeStruct"}}}, + structs: map[string]*StructDeclaration{ + "NodeStruct": { + name: "NodeStruct", + fields: []*Field{}}}, + } + + assert.Equal(t, expected.String(), Transform(input).String()) +} + +func TestAstWithArray(t *testing.T) { + /* + type NodeInterface interface { + iNode() + } + + func (*NodeArray) iNode{} + + type NodeArray []NodeInterface + */ + input := &SourceFile{ + lines: []Sast{ + &InterfaceDeclaration{ + name: "NodeInterface"}, + &TypeAlias{ + name: "NodeArray", + typ: &Array{&TypeString{"NodeInterface"}}, + }, + &FuncDeclaration{ + receiver: &Field{ + name: "", + typ: &Ref{&TypeString{"NodeArray"}}, + }, + name: "iNode", + block: "", + arguments: []*Field{}}, + }, + } + + expected := &SourceInformation{ + interestingTypes: map[string]Type{ + "*NodeArray": &Ref{&TypeString{"NodeArray"}}}, + structs: map[string]*StructDeclaration{}, + typeAliases: map[string]*TypeAlias{ + "NodeArray": { + name: "NodeArray", + typ: &Array{&TypeString{"NodeInterface"}}, + }, + }, + } + + result := Transform(input) + + assert.Equal(t, expected.String(), result.String()) +} diff --git a/go/visitorgen/visitor_emitter.go b/go/visitorgen/visitor_emitter.go new file mode 100644 index 00000000000..889c05fe7f7 --- /dev/null +++ b/go/visitorgen/visitor_emitter.go @@ 
-0,0 +1,76 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package visitorgen + +import ( + "fmt" + "strings" +) + +// EmitReplacementMethods is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, +// and produces a string from it. This method will produce the replacement methods that make it possible to +// replace objects in fields or in slices. +func EmitReplacementMethods(vd *VisitorPlan) string { + var sb builder + for _, s := range vd.Switches { + for _, k := range s.Fields { + sb.appendF(k.asReplMethod()) + sb.newLine() + } + } + + return sb.String() +} + +// EmitTypeSwitches is an anti-parser (a.k.a prettifier) - it takes a struct that is much like an AST, +// and produces a string from it. This method will produce the switch cases needed to cover the Vitess AST. 
+func EmitTypeSwitches(vd *VisitorPlan) string { + var sb builder + for _, s := range vd.Switches { + sb.newLine() + sb.appendF(" case %s:", s.Type.toTypString()) + for _, k := range s.Fields { + sb.appendF(k.asSwitchCase()) + } + } + + return sb.String() +} + +func (b *builder) String() string { + return strings.TrimSpace(b.sb.String()) +} + +type builder struct { + sb strings.Builder +} + +func (b *builder) appendF(format string, data ...interface{}) *builder { + _, err := b.sb.WriteString(fmt.Sprintf(format, data...)) + if err != nil { + panic(err) + } + b.newLine() + return b +} + +func (b *builder) newLine() { + _, err := b.sb.WriteString("\n") + if err != nil { + panic(err) + } +} diff --git a/go/visitorgen/visitor_emitter_test.go b/go/visitorgen/visitor_emitter_test.go new file mode 100644 index 00000000000..94666daa743 --- /dev/null +++ b/go/visitorgen/visitor_emitter_test.go @@ -0,0 +1,92 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package visitorgen + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSingleItem(t *testing.T) { + sfi := SingleFieldItem{ + StructType: &Ref{&TypeString{"Struct"}}, + FieldType: &TypeString{"string"}, + FieldName: "Field", + } + + expectedReplacer := `func replaceStructField(newNode, parent SQLNode) { + parent.(*Struct).Field = newNode.(string) +}` + + expectedSwitch := ` a.apply(node, n.Field, replaceStructField)` + require.Equal(t, expectedReplacer, sfi.asReplMethod()) + require.Equal(t, expectedSwitch, sfi.asSwitchCase()) +} + +func TestArrayFieldItem(t *testing.T) { + sfi := ArrayFieldItem{ + StructType: &Ref{&TypeString{"Struct"}}, + ItemType: &TypeString{"string"}, + FieldName: "Field", + } + + expectedReplacer := `type replaceStructField int + +func (r *replaceStructField) replace(newNode, container SQLNode) { + container.(*Struct).Field[int(*r)] = newNode.(string) +} + +func (r *replaceStructField) inc() { + *r++ +}` + + expectedSwitch := ` replacerField := replaceStructField(0) + replacerFieldB := &replacerField + for _, item := range n.Field { + a.apply(node, item, replacerFieldB.replace) + replacerFieldB.inc() + }` + require.Equal(t, expectedReplacer, sfi.asReplMethod()) + require.Equal(t, expectedSwitch, sfi.asSwitchCase()) +} + +func TestArrayItem(t *testing.T) { + sfi := ArrayItem{ + StructType: &Ref{&TypeString{"Struct"}}, + ItemType: &TypeString{"string"}, + } + + expectedReplacer := `type replaceStructItems int + +func (r *replaceStructItems) replace(newNode, container SQLNode) { + container.(*Struct)[int(*r)] = newNode.(string) +} + +func (r *replaceStructItems) inc() { + *r++ +}` + + expectedSwitch := ` replacer := replaceStructItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + }` + require.Equal(t, expectedReplacer, sfi.asReplMethod()) + require.Equal(t, expectedSwitch, sfi.asSwitchCase()) +} diff --git a/go/visitorgen/visitorgen.go 
b/go/visitorgen/visitorgen.go new file mode 100644 index 00000000000..284f8c4d9be --- /dev/null +++ b/go/visitorgen/visitorgen.go @@ -0,0 +1,33 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//Package visitorgen is responsible for taking the ast.go of Vitess and +//and producing visitor infrastructure for it. +// +//This is accomplished in a few steps. +//Step 1: Walk the AST and collect the interesting information into a format that is +// easy to consume for the next step. The output format is a *SourceFile, that +// contains the needed information in a format that is pretty close to the golang ast, +// but simplified +//Step 2: A SourceFile is packaged into a SourceInformation. SourceInformation is still +// concerned with the input ast - it's just an even more distilled and easy to +// consume format for the last step. This step is performed by the code in transformer.go. +//Step 3: Using the SourceInformation, the struct_producer.go code produces the final data structure +// used, a VisitorPlan. This is focused on the output - it contains a list of all fields or +// arrays that need to be handled by the visitor produced. +//Step 4: The VisitorPlan is lastly turned into a string that is written as the output of +// this whole process. 
+package visitorgen diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 58dad3940bc..68de87b3a29 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -17,470 +17,812 @@ limitations under the License. package sqlparser import ( - "encoding/hex" - "encoding/json" - "errors" "fmt" - "io" "strings" - "sync" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/vterrors" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -// parserPool is a pool for parser objects. -var parserPool = sync.Pool{} - -// zeroParser is a zero-initialized parser to help reinitialize the parser for pooling. -var zeroParser = *(yyNewParser().(*yyParserImpl)) - -// yyParsePooled is a wrapper around yyParse that pools the parser objects. There isn't a -// particularly good reason to use yyParse directly, since it immediately discards its parser. What -// would be ideal down the line is to actually pool the stacks themselves rather than the parser -// objects, as per https://github.com/cznic/goyacc/blob/master/main.go. However, absent an upstream -// change to goyacc, this is the next best option. -// -// N.B: Parser pooling means that you CANNOT take references directly to parse stack variables (e.g. -// $$ = &$4) in sql.y rules. You must instead add an intermediate reference like so: -// showCollationFilterOpt := $4 -// $$ = &Show{Type: string($2), ShowCollationFilterOpt: &showCollationFilterOpt} -func yyParsePooled(yylex yyLexer) int { - // Being very particular about using the base type and not an interface type b/c we depend on - // the implementation to know how to reinitialize the parser. 
- var parser *yyParserImpl - - i := parserPool.Get() - if i != nil { - parser = i.(*yyParserImpl) - } else { - parser = yyNewParser().(*yyParserImpl) - } - - defer func() { - *parser = zeroParser - parserPool.Put(parser) - }() - return parser.Parse(yylex) -} - -// Instructions for creating new types: If a type -// needs to satisfy an interface, declare that function -// along with that interface. This will help users -// identify the list of types to which they can assert -// those interfaces. -// If the member of a type has a string with a predefined -// list of values, declare those values as const following -// the type. -// For interfaces that define dummy functions to consolidate -// a set of types, define the function as iTypeName. -// This will help avoid name collisions. - -// Parse parses the SQL in full and returns a Statement, which -// is the AST representation of the query. If a DDL statement -// is partially parsed but still contains a syntax error, the -// error is ignored and the DDL is returned anyway. -func Parse(sql string) (Statement, error) { - tokenizer := NewStringTokenizer(sql) - if yyParsePooled(tokenizer) != 0 { - if tokenizer.partialDDL != nil { - if typ, val := tokenizer.Scan(); typ != 0 { - return nil, fmt.Errorf("extra characters encountered after end of DDL: '%s'", string(val)) - } - log.Warningf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError) - tokenizer.ParseTree = tokenizer.partialDDL - return tokenizer.ParseTree, nil - } - return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, tokenizer.LastError.Error()) - } - if tokenizer.ParseTree == nil { - return nil, ErrEmpty - } - return tokenizer.ParseTree, nil +/* +This is the Vitess AST. This file should only contain pure struct declarations, +or methods used to mark a struct as implementing an interface. All other methods +related to these structs live in ast_funcs.go +*/ + +// SQLNode defines the interface for all nodes +// generated by the parser. 
+type SQLNode interface { + Format(buf *TrackedBuffer) } -// ParseStrictDDL is the same as Parse except it errors on -// partially parsed DDL statements. -func ParseStrictDDL(sql string) (Statement, error) { - tokenizer := NewStringTokenizer(sql) - if yyParsePooled(tokenizer) != 0 { - return nil, tokenizer.LastError - } - if tokenizer.ParseTree == nil { - return nil, ErrEmpty - } - return tokenizer.ParseTree, nil +// Statements +type ( + // Statement represents a statement. + Statement interface { + iStatement() + SQLNode + } + + // SelectStatement any SELECT statement. + SelectStatement interface { + iSelectStatement() + iStatement() + iInsertRows() + AddOrder(*Order) + SetLimit(*Limit) + SQLNode + } + + // Select represents a SELECT statement. + Select struct { + Cache string + Comments Comments + Distinct string + Hints string + SelectExprs SelectExprs + From TableExprs + Where *Where + GroupBy GroupBy + Having *Where + OrderBy OrderBy + Limit *Limit + Lock string + } + + // Union represents a UNION statement. + Union struct { + Type string + Left, Right SelectStatement + OrderBy OrderBy + Limit *Limit + Lock string + } + + // Stream represents a SELECT statement. + Stream struct { + Comments Comments + SelectExpr SelectExpr + Table TableName + } + + // Insert represents an INSERT or REPLACE statement. + // Per the MySQL docs, http://dev.mysql.com/doc/refman/5.7/en/replace.html + // Replace is the counterpart to `INSERT IGNORE`, and works exactly like a + // normal INSERT except if the row exists. In that case it first deletes + // the row and re-inserts with new values. For that reason we keep it as an Insert struct. + // Replaces are currently disallowed in sharded schemas because + // of the implications the deletion part may have on vindexes. + // If you add fields here, consider adding them to calls to validateUnshardedRoute. 
+ Insert struct { + Action string + Comments Comments + Ignore string + Table TableName + Partitions Partitions + Columns Columns + Rows InsertRows + OnDup OnDup + } + + // Update represents an UPDATE statement. + // If you add fields here, consider adding them to calls to validateUnshardedRoute. + Update struct { + Comments Comments + Ignore string + TableExprs TableExprs + Exprs UpdateExprs + Where *Where + OrderBy OrderBy + Limit *Limit + } + + // Delete represents a DELETE statement. + // If you add fields here, consider adding them to calls to validateUnshardedRoute. + Delete struct { + Comments Comments + Targets TableNames + TableExprs TableExprs + Partitions Partitions + Where *Where + OrderBy OrderBy + Limit *Limit + } + + // Set represents a SET statement. + Set struct { + Comments Comments + Exprs SetExprs + Scope string + } + + // DBDDL represents a CREATE, DROP, or ALTER database statement. + DBDDL struct { + Action string + DBName string + IfExists bool + Collate string + Charset string + } + + // DDL represents a CREATE, ALTER, DROP, RENAME, TRUNCATE or ANALYZE statement. + DDL struct { + Action string + + // FromTables is set if Action is RenameStr or DropStr. + FromTables TableNames + + // ToTables is set if Action is RenameStr. + ToTables TableNames + + // Table is set if Action is other than RenameStr or DropStr. + Table TableName + + // The following fields are set if a DDL was fully analyzed. + IfExists bool + TableSpec *TableSpec + OptLike *OptLike + PartitionSpec *PartitionSpec + + // VindexSpec is set for CreateVindexStr, DropVindexStr, AddColVindexStr, DropColVindexStr. + VindexSpec *VindexSpec + + // VindexCols is set for AddColVindexStr. + VindexCols []ColIdent + + // AutoIncSpec is set for AddAutoIncStr. + AutoIncSpec *AutoIncSpec + } + + // ParenSelect is a parenthesized SELECT statement. + ParenSelect struct { + Select SelectStatement + } + + // Show represents a show statement. 
+ Show struct { + Type string + OnTable TableName + Table TableName + ShowTablesOpt *ShowTablesOpt + Scope string + ShowCollationFilterOpt *Expr // TODO: this should not be a pointer + } + + // Use represents a use statement. + Use struct { + DBName TableIdent + } + + // Begin represents a Begin statement. + Begin struct{} + + // Commit represents a Commit statement. + Commit struct{} + + // Rollback represents a Rollback statement. + Rollback struct{} + + // OtherRead represents a DESCRIBE, or EXPLAIN statement. + // It should be used only as an indicator. It does not contain + // the full AST for the statement. + OtherRead struct{} + + // OtherAdmin represents a misc statement that relies on ADMIN privileges, + // such as REPAIR, OPTIMIZE, or TRUNCATE statement. + // It should be used only as an indicator. It does not contain + // the full AST for the statement. + OtherAdmin struct{} +) + +func (*Union) iStatement() {} +func (*Select) iStatement() {} +func (*Stream) iStatement() {} +func (*Insert) iStatement() {} +func (*Update) iStatement() {} +func (*Delete) iStatement() {} +func (*Set) iStatement() {} +func (*DBDDL) iStatement() {} +func (*DDL) iStatement() {} +func (*Show) iStatement() {} +func (*Use) iStatement() {} +func (*Begin) iStatement() {} +func (*Commit) iStatement() {} +func (*Rollback) iStatement() {} +func (*OtherRead) iStatement() {} +func (*OtherAdmin) iStatement() {} +func (*Select) iSelectStatement() {} +func (*Union) iSelectStatement() {} +func (*ParenSelect) iSelectStatement() {} + +// ParenSelect can actually not be a top level statement, +// but we have to allow it because it's a requirement +// of SelectStatement. +func (*ParenSelect) iStatement() {} + +// InsertRows represents the rows for an INSERT statement. +type InsertRows interface { + iInsertRows() + SQLNode } -// ParseTokenizer is a raw interface to parse from the given tokenizer. -// This does not used pooled parsers, and should not be used in general. 
-func ParseTokenizer(tokenizer *Tokenizer) int { - return yyParse(tokenizer) +func (*Select) iInsertRows() {} +func (*Union) iInsertRows() {} +func (Values) iInsertRows() {} +func (*ParenSelect) iInsertRows() {} + +// OptLike works for create table xxx like xxx +type OptLike struct { + LikeTable TableName } -// ParseNext parses a single SQL statement from the tokenizer -// returning a Statement which is the AST representation of the query. -// The tokenizer will always read up to the end of the statement, allowing for -// the next call to ParseNext to parse any subsequent SQL statements. When -// there are no more statements to parse, a error of io.EOF is returned. -func ParseNext(tokenizer *Tokenizer) (Statement, error) { - return parseNext(tokenizer, false) +// PartitionSpec describe partition actions (for alter and create) +type PartitionSpec struct { + Action string + Name ColIdent + Definitions []*PartitionDefinition } -// ParseNextStrictDDL is the same as ParseNext except it errors on -// partially parsed DDL statements. 
-func ParseNextStrictDDL(tokenizer *Tokenizer) (Statement, error) { - return parseNext(tokenizer, true) +// PartitionDefinition describes a very minimal partition definition +type PartitionDefinition struct { + Name ColIdent + Limit Expr + Maxvalue bool } -func parseNext(tokenizer *Tokenizer, strict bool) (Statement, error) { - if tokenizer.lastChar == ';' { - tokenizer.next() - tokenizer.skipBlank() - } - if tokenizer.lastChar == eofChar { - return nil, io.EOF - } +// TableSpec describes the structure of a table from a CREATE TABLE statement +type TableSpec struct { + Columns []*ColumnDefinition + Indexes []*IndexDefinition + Constraints []*ConstraintDefinition + Options string +} - tokenizer.reset() - tokenizer.multi = true - if yyParsePooled(tokenizer) != 0 { - if tokenizer.partialDDL != nil && !strict { - tokenizer.ParseTree = tokenizer.partialDDL - return tokenizer.ParseTree, nil - } - return nil, tokenizer.LastError - } - if tokenizer.ParseTree == nil { - return ParseNext(tokenizer) - } - return tokenizer.ParseTree, nil +// ColumnDefinition describes a column in a CREATE TABLE statement +type ColumnDefinition struct { + Name ColIdent + // TODO: Should this not be a reference? + Type ColumnType } -// ErrEmpty is a sentinel error returned when parsing empty statements. 
-var ErrEmpty = errors.New("empty statement") +// ColumnType represents a sql type in a CREATE TABLE statement +// All optional fields are nil if not specified +type ColumnType struct { + // The base type string + Type string -// SplitStatement returns the first sql statement up to either a ; or EOF -// and the remainder from the given buffer -func SplitStatement(blob string) (string, string, error) { - tokenizer := NewStringTokenizer(blob) - tkn := 0 - for { - tkn, _ = tokenizer.Scan() - if tkn == 0 || tkn == ';' || tkn == eofChar { - break - } - } - if tokenizer.LastError != nil { - return "", "", tokenizer.LastError - } - if tkn == ';' { - return blob[:tokenizer.Position-2], blob[tokenizer.Position-1:], nil - } - return blob, "", nil -} + // Generic field options. + NotNull BoolVal + Autoincrement BoolVal + Default Expr + OnUpdate Expr + Comment *SQLVal -// SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces -// returns the sql pieces blob contains; or error if sql cannot be parsed -func SplitStatementToPieces(blob string) (pieces []string, err error) { - pieces = make([]string, 0, 16) - tokenizer := NewStringTokenizer(blob) + // Numeric field options + Length *SQLVal + Unsigned BoolVal + Zerofill BoolVal + Scale *SQLVal - tkn := 0 - var stmt string - stmtBegin := 0 - for { - tkn, _ = tokenizer.Scan() - if tkn == ';' { - stmt = blob[stmtBegin : tokenizer.Position-2] - pieces = append(pieces, stmt) - stmtBegin = tokenizer.Position - 1 + // Text field options + Charset string + Collate string - } else if tkn == 0 || tkn == eofChar { - blobTail := tokenizer.Position - 2 + // Enum values + EnumValues []string - if stmtBegin < blobTail { - stmt = blob[stmtBegin : blobTail+1] - pieces = append(pieces, stmt) - } - break - } - } + // Key specification + KeyOpt ColumnKeyOption +} - err = tokenizer.LastError - return +// IndexDefinition describes an index in a CREATE TABLE statement +type IndexDefinition struct { + Info *IndexInfo + 
Columns []*IndexColumn + Options []*IndexOption } -// SQLNode defines the interface for all nodes -// generated by the parser. -type SQLNode interface { - Format(buf *TrackedBuffer) - // walkSubtree calls visit on all underlying nodes - // of the subtree, but not the current one. Walking - // must be interrupted if visit returns an error. - walkSubtree(visit Visit) error -} - -// Visit defines the signature of a function that -// can be used to visit all nodes of a parse tree. -type Visit func(node SQLNode) (kontinue bool, err error) - -// Walk calls visit on every node. -// If visit returns true, the underlying nodes -// are also visited. If it returns an error, walking -// is interrupted, and the error is returned. -func Walk(visit Visit, nodes ...SQLNode) error { - for _, node := range nodes { - if node == nil { - continue - } - kontinue, err := visit(node) - if err != nil { - return err - } - if kontinue { - err = node.walkSubtree(visit) - if err != nil { - return err - } - } - } - return nil +// IndexInfo describes the name and type of an index in a CREATE TABLE statement +type IndexInfo struct { + Type string + Name ColIdent + Primary bool + Spatial bool + Unique bool } -// String returns a string representation of an SQLNode. -func String(node SQLNode) string { - if node == nil { - return "" - } +// VindexSpec defines a vindex for a CREATE VINDEX or DROP VINDEX statement +type VindexSpec struct { + Name ColIdent + Type ColIdent + Params []VindexParam +} - buf := NewTrackedBuffer(nil) - buf.Myprintf("%v", node) - return buf.String() +// AutoIncSpec defines and autoincrement value for a ADD AUTO_INCREMENT statement +type AutoIncSpec struct { + Column ColIdent + Sequence TableName } -// Append appends the SQLNode to the buffer. 
-func Append(buf *strings.Builder, node SQLNode) { - tbuf := &TrackedBuffer{ - Builder: buf, - } - node.Format(tbuf) +// VindexParam defines a key/value parameter for a CREATE VINDEX statement +type VindexParam struct { + Key ColIdent + Val string } -// Statement represents a statement. -type Statement interface { - iStatement() - SQLNode +// ConstraintDefinition describes a constraint in a CREATE TABLE statement +type ConstraintDefinition struct { + Name string + Details ConstraintInfo } -func (*Union) iStatement() {} -func (*Select) iStatement() {} -func (*Stream) iStatement() {} -func (*Insert) iStatement() {} -func (*Update) iStatement() {} -func (*Delete) iStatement() {} -func (*Set) iStatement() {} -func (*DBDDL) iStatement() {} -func (*DDL) iStatement() {} -func (*Show) iStatement() {} -func (*Use) iStatement() {} -func (*Begin) iStatement() {} -func (*Commit) iStatement() {} -func (*Rollback) iStatement() {} -func (*OtherRead) iStatement() {} -func (*OtherAdmin) iStatement() {} +type ( + // ConstraintInfo details a constraint in a CREATE TABLE statement + ConstraintInfo interface { + SQLNode + iConstraintInfo() + } -// ParenSelect can actually not be a top level statement, -// but we have to allow it because it's a requirement -// of SelectStatement. -func (*ParenSelect) iStatement() {} + // ForeignKeyDefinition describes a foreign key in a CREATE TABLE statement + ForeignKeyDefinition struct { + Source Columns + ReferencedTable TableName + ReferencedColumns Columns + OnDelete ReferenceAction + OnUpdate ReferenceAction + } +) -// SelectStatement any SELECT statement. -type SelectStatement interface { - iSelectStatement() - iStatement() - iInsertRows() - AddOrder(*Order) - SetLimit(*Limit) - SQLNode +// ShowFilter is show tables filter +type ShowFilter struct { + Like string + Filter Expr } -func (*Select) iSelectStatement() {} -func (*Union) iSelectStatement() {} -func (*ParenSelect) iSelectStatement() {} +// Comments represents a list of comments. 
+type Comments [][]byte + +// SelectExprs represents SELECT expressions. +type SelectExprs []SelectExpr + +type ( + // SelectExpr represents a SELECT expression. + SelectExpr interface { + iSelectExpr() + SQLNode + } -// Select represents a SELECT statement. -type Select struct { - Cache string - Comments Comments - Distinct string - Hints string - SelectExprs SelectExprs - From TableExprs - Where *Where - GroupBy GroupBy - Having *Where - OrderBy OrderBy - Limit *Limit - Lock string -} - -// Select.Distinct -const ( - DistinctStr = "distinct " - StraightJoinHint = "straight_join " + // StarExpr defines a '*' or 'table.*' expression. + StarExpr struct { + TableName TableName + } + + // AliasedExpr defines an aliased SELECT expression. + AliasedExpr struct { + Expr Expr + As ColIdent + } + + // Nextval defines the NEXT VALUE expression. + Nextval struct { + Expr Expr + } ) -// Select.Lock -const ( - ForUpdateStr = " for update" - ShareModeStr = " lock in share mode" +func (*StarExpr) iSelectExpr() {} +func (*AliasedExpr) iSelectExpr() {} +func (Nextval) iSelectExpr() {} + +// Columns represents an insert column list. +type Columns []ColIdent + +// Partitions is a type alias for Columns so we can handle printing efficiently +type Partitions Columns + +// TableExprs represents a list of table expressions. +type TableExprs []TableExpr + +type ( + // TableExpr represents a table expression. + TableExpr interface { + iTableExpr() + SQLNode + } + + // AliasedTableExpr represents a table expression + // coupled with an optional alias or index hint. + // If As is empty, no alias was used. + AliasedTableExpr struct { + Expr SimpleTableExpr + Partitions Partitions + As TableIdent + Hints *IndexHints + } + + // JoinTableExpr represents a TableExpr that's a JOIN operation. + JoinTableExpr struct { + LeftExpr TableExpr + Join string + RightExpr TableExpr + Condition JoinCondition + } + + // ParenTableExpr represents a parenthesized list of TableExpr. 
+ ParenTableExpr struct { + Exprs TableExprs + } ) -// Select.Cache -const ( - SQLCacheStr = "sql_cache " - SQLNoCacheStr = "sql_no_cache " +func (*AliasedTableExpr) iTableExpr() {} +func (*ParenTableExpr) iTableExpr() {} +func (*JoinTableExpr) iTableExpr() {} + +type ( + // SimpleTableExpr represents a simple table expression. + SimpleTableExpr interface { + iSimpleTableExpr() + SQLNode + } + + // TableName represents a table name. + // Qualifier, if specified, represents a database or keyspace. + // TableName is a value struct whose fields are case sensitive. + // This means two TableName vars can be compared for equality + // and a TableName can also be used as key in a map. + TableName struct { + Name, Qualifier TableIdent + } + + // Subquery represents a subquery. + Subquery struct { + Select SelectStatement + } ) -// AddOrder adds an order by element -func (node *Select) AddOrder(order *Order) { - node.OrderBy = append(node.OrderBy, order) +func (TableName) iSimpleTableExpr() {} +func (*Subquery) iSimpleTableExpr() {} + +// TableNames is a list of TableName. +type TableNames []TableName + +// JoinCondition represents the join conditions (either a ON or USING clause) +// of a JoinTableExpr. +type JoinCondition struct { + On Expr + Using Columns } -// SetLimit sets the limit clause -func (node *Select) SetLimit(limit *Limit) { - node.Limit = limit +// IndexHints represents a list of index hints. +type IndexHints struct { + Type string + Indexes []ColIdent } -// Format formats the node. -func (node *Select) Format(buf *TrackedBuffer) { - buf.Myprintf("select %v%s%s%s%v from %v%v%v%v%v%v%s", - node.Comments, node.Cache, node.Distinct, node.Hints, node.SelectExprs, - node.From, node.Where, - node.GroupBy, node.Having, node.OrderBy, - node.Limit, node.Lock) +// Where represents a WHERE or HAVING clause. 
+type Where struct { + Type string + Expr Expr } -func (node *Select) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.SelectExprs, - node.From, - node.Where, - node.GroupBy, - node.Having, - node.OrderBy, - node.Limit, - ) -} - -// AddWhere adds the boolean expression to the -// WHERE clause as an AND condition. If the expression -// is an OR clause, it parenthesizes it. Currently, -// the OR operator is the only one that's lower precedence -// than AND. -func (node *Select) AddWhere(expr Expr) { - if _, ok := expr.(*OrExpr); ok { - expr = &ParenExpr{Expr: expr} - } - if node.Where == nil { - node.Where = &Where{ - Type: WhereStr, - Expr: expr, - } - return +// *********** Expressions +type ( + // Expr represents an expression. + Expr interface { + iExpr() + SQLNode } - node.Where.Expr = &AndExpr{ - Left: node.Where.Expr, - Right: expr, + + // AndExpr represents an AND expression. + AndExpr struct { + Left, Right Expr } -} -// AddHaving adds the boolean expression to the -// HAVING clause as an AND condition. If the expression -// is an OR clause, it parenthesizes it. Currently, -// the OR operator is the only one that's lower precedence -// than AND. -func (node *Select) AddHaving(expr Expr) { - if _, ok := expr.(*OrExpr); ok { - expr = &ParenExpr{Expr: expr} + // OrExpr represents an OR expression. + OrExpr struct { + Left, Right Expr } - if node.Having == nil { - node.Having = &Where{ - Type: HavingStr, - Expr: expr, - } - return + + // NotExpr represents a NOT expression. + NotExpr struct { + Expr Expr + } + + // ParenExpr represents a parenthesized boolean expression. + ParenExpr struct { + Expr Expr + } + + // ComparisonExpr represents a two-value comparison expression. + ComparisonExpr struct { + Operator string + Left, Right Expr + Escape Expr + } + + // RangeCond represents a BETWEEN or a NOT BETWEEN expression. 
+ RangeCond struct { + Operator string + Left Expr + From, To Expr + } + + // IsExpr represents an IS ... or an IS NOT ... expression. + IsExpr struct { + Operator string + Expr Expr + } + + // ExistsExpr represents an EXISTS expression. + ExistsExpr struct { + Subquery *Subquery + } + + // SQLVal represents a single value. + SQLVal struct { + Type ValType + Val []byte + } + + // NullVal represents a NULL value. + NullVal struct{} + + // BoolVal is true or false. + BoolVal bool + + // ColName represents a column name. + ColName struct { + // Metadata is not populated by the parser. + // It's a placeholder for analyzers to store + // additional data, typically info about which + // table or column this node references. + Metadata interface{} + Name ColIdent + Qualifier TableName + } + + // ColTuple represents a list of column values. + // It can be ValTuple, Subquery, ListArg. + ColTuple interface { + iColTuple() + Expr + } + + // ListArg represents a named list argument. + ListArg []byte + + // ValTuple represents a tuple of actual values. + ValTuple Exprs + + // BinaryExpr represents a binary value expression. + BinaryExpr struct { + Operator string + Left, Right Expr + } + + // UnaryExpr represents a unary value expression. + UnaryExpr struct { + Operator string + Expr Expr + } + + // IntervalExpr represents a date-time INTERVAL expression. + IntervalExpr struct { + Expr Expr + Unit string + } + + // TimestampFuncExpr represents the function and arguments for TIMESTAMP{ADD,DIFF} functions. + TimestampFuncExpr struct { + Name string + Expr1 Expr + Expr2 Expr + Unit string + } + + // CollateExpr represents dynamic collate operator. + CollateExpr struct { + Expr Expr + Charset string } - node.Having.Expr = &AndExpr{ - Left: node.Having.Expr, - Right: expr, + + // FuncExpr represents a function call. 
+ FuncExpr struct { + Qualifier TableIdent + Name ColIdent + Distinct bool + Exprs SelectExprs + } + + // GroupConcatExpr represents a call to GROUP_CONCAT + GroupConcatExpr struct { + Distinct string + Exprs SelectExprs + OrderBy OrderBy + Separator string + } + + // ValuesFuncExpr represents a function call. + ValuesFuncExpr struct { + Name *ColName + } + + // SubstrExpr represents a call to SubstrExpr(column, value_expression) or SubstrExpr(column, value_expression,value_expression) + // also supported syntax SubstrExpr(column from value_expression for value_expression). + // Additionally to column names, SubstrExpr is also supported for string values, e.g.: + // SubstrExpr('static string value', value_expression, value_expression) + // In this case StrVal will be set instead of Name. + SubstrExpr struct { + Name *ColName + StrVal *SQLVal + From Expr + To Expr } + + // ConvertExpr represents a call to CONVERT(expr, type) + // or it's equivalent CAST(expr AS type). Both are rewritten to the former. + ConvertExpr struct { + Expr Expr + Type *ConvertType + } + + // ConvertUsingExpr represents a call to CONVERT(expr USING charset). + ConvertUsingExpr struct { + Expr Expr + Type string + } + + // MatchExpr represents a call to the MATCH function + MatchExpr struct { + Columns SelectExprs + Expr Expr + Option string + } + + // CaseExpr represents a CASE expression. + CaseExpr struct { + Expr Expr + Whens []*When + Else Expr + } + + // Default represents a DEFAULT expression. + Default struct { + ColName string + } + + // When represents a WHEN sub-expression. 
+ When struct { + Cond Expr + Val Expr + } + + // CurTimeFuncExpr represents the function and arguments for CURRENT DATE/TIME functions + // supported functions are documented in the grammar + CurTimeFuncExpr struct { + Name ColIdent + Fsp Expr // fractional seconds precision, integer from 0 to 6 + } +) + +// iExpr ensures that only expressions nodes can be assigned to a Expr +func (*AndExpr) iExpr() {} +func (*OrExpr) iExpr() {} +func (*NotExpr) iExpr() {} +func (*ParenExpr) iExpr() {} +func (*ComparisonExpr) iExpr() {} +func (*RangeCond) iExpr() {} +func (*IsExpr) iExpr() {} +func (*ExistsExpr) iExpr() {} +func (*SQLVal) iExpr() {} +func (*NullVal) iExpr() {} +func (BoolVal) iExpr() {} +func (*ColName) iExpr() {} +func (ValTuple) iExpr() {} +func (*Subquery) iExpr() {} +func (ListArg) iExpr() {} +func (*BinaryExpr) iExpr() {} +func (*UnaryExpr) iExpr() {} +func (*IntervalExpr) iExpr() {} +func (*CollateExpr) iExpr() {} +func (*FuncExpr) iExpr() {} +func (*TimestampFuncExpr) iExpr() {} +func (*CurTimeFuncExpr) iExpr() {} +func (*CaseExpr) iExpr() {} +func (*ValuesFuncExpr) iExpr() {} +func (*ConvertExpr) iExpr() {} +func (*SubstrExpr) iExpr() {} +func (*ConvertUsingExpr) iExpr() {} +func (*MatchExpr) iExpr() {} +func (*GroupConcatExpr) iExpr() {} +func (*Default) iExpr() {} + +// Exprs represents a list of value expressions. +// It's not a valid expression because it's not parenthesized. +type Exprs []Expr + +func (ValTuple) iColTuple() {} +func (*Subquery) iColTuple() {} +func (ListArg) iColTuple() {} + +// ConvertType represents the type in call to CONVERT(expr, type) +type ConvertType struct { + Type string + Length *SQLVal + Scale *SQLVal + Operator string + Charset string } -// ParenSelect is a parenthesized SELECT statement. -type ParenSelect struct { - Select SelectStatement +// GroupBy represents a GROUP BY clause. +type GroupBy []Expr + +// OrderBy represents an ORDER By clause. +type OrderBy []*Order + +// Order represents an ordering expression. 
+type Order struct { + Expr Expr + Direction string } -// AddOrder adds an order by element -func (node *ParenSelect) AddOrder(order *Order) { - panic("unreachable") +// Limit represents a LIMIT clause. +type Limit struct { + Offset, Rowcount Expr } -// SetLimit sets the limit clause -func (node *ParenSelect) SetLimit(limit *Limit) { - panic("unreachable") +// Values represents a VALUES clause. +type Values []ValTuple + +// UpdateExprs represents a list of update expressions. +type UpdateExprs []*UpdateExpr + +// UpdateExpr represents an update expression. +type UpdateExpr struct { + Name *ColName + Expr Expr } -// Format formats the node. -func (node *ParenSelect) Format(buf *TrackedBuffer) { - buf.Myprintf("(%v)", node.Select) +// SetExprs represents a list of set expressions. +type SetExprs []*SetExpr + +// SetExpr represents a set expression. +type SetExpr struct { + Name ColIdent + Expr Expr } -func (node *ParenSelect) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Select, - ) +// OnDup represents an ON DUPLICATE KEY clause. +type OnDup UpdateExprs + +// ColIdent is a case insensitive SQL identifier. It will be escaped with +// backquotes if necessary. +type ColIdent struct { + // This artifact prevents this struct from being compared + // with itself. It consumes no space as long as it's not the + // last field in the struct. + _ [0]struct{ _ []byte } + val, lowered string } -// Union represents a UNION statement. -type Union struct { - Type string - Left, Right SelectStatement - OrderBy OrderBy - Limit *Limit - Lock string +// TableIdent is a case sensitive SQL identifier. It will be escaped with +// backquotes if necessary. 
+type TableIdent struct { + v string } -// Union.Type -const ( - UnionStr = "union" - UnionAllStr = "union all" - UnionDistinctStr = "union distinct" -) +// Here follow all the Format implementations for AST nodes -// AddOrder adds an order by element -func (node *Union) AddOrder(order *Order) { - node.OrderBy = append(node.OrderBy, order) +// Format formats the node. +func (node *Select) Format(buf *TrackedBuffer) { + buf.Myprintf("select %v%s%s%s%v from %v%v%v%v%v%v%s", + node.Comments, node.Cache, node.Distinct, node.Hints, node.SelectExprs, + node.From, node.Where, + node.GroupBy, node.Having, node.OrderBy, + node.Limit, node.Lock) } -// SetLimit sets the limit clause -func (node *Union) SetLimit(limit *Limit) { - node.Limit = limit +// Format formats the node. +func (node *ParenSelect) Format(buf *TrackedBuffer) { + buf.Myprintf("(%v)", node.Select) } // Format formats the node. @@ -489,110 +831,18 @@ func (node *Union) Format(buf *TrackedBuffer) { node.OrderBy, node.Limit, node.Lock) } -func (node *Union) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.Right, - ) -} - -// Stream represents a SELECT statement. -type Stream struct { - Comments Comments - SelectExpr SelectExpr - Table TableName -} - // Format formats the node. func (node *Stream) Format(buf *TrackedBuffer) { buf.Myprintf("stream %v%v from %v", node.Comments, node.SelectExpr, node.Table) } -func (node *Stream) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.SelectExpr, - node.Table, - ) -} - -// Insert represents an INSERT or REPLACE statement. -// Per the MySQL docs, http://dev.mysql.com/doc/refman/5.7/en/replace.html -// Replace is the counterpart to `INSERT IGNORE`, and works exactly like a -// normal INSERT except if the row exists. In that case it first deletes -// the row and re-inserts with new values. For that reason we keep it as an Insert struct. 
-// Replaces are currently disallowed in sharded schemas because -// of the implications the deletion part may have on vindexes. -// If you add fields here, consider adding them to calls to validateUnshardedRoute. -type Insert struct { - Action string - Comments Comments - Ignore string - Table TableName - Partitions Partitions - Columns Columns - Rows InsertRows - OnDup OnDup -} - -// DDL strings. -const ( - InsertStr = "insert" - ReplaceStr = "replace" -) - // Format formats the node. func (node *Insert) Format(buf *TrackedBuffer) { buf.Myprintf("%s %v%sinto %v%v%v %v%v", node.Action, - node.Comments, node.Ignore, - node.Table, node.Partitions, node.Columns, node.Rows, node.OnDup) -} - -func (node *Insert) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.Table, - node.Columns, - node.Rows, - node.OnDup, - ) -} - -// InsertRows represents the rows for an INSERT statement. -type InsertRows interface { - iInsertRows() - SQLNode -} - -func (*Select) iInsertRows() {} -func (*Union) iInsertRows() {} -func (Values) iInsertRows() {} -func (*ParenSelect) iInsertRows() {} - -// Update represents an UPDATE statement. -// If you add fields here, consider adding them to calls to validateUnshardedRoute. -type Update struct { - Comments Comments - Ignore string - TableExprs TableExprs - Exprs UpdateExprs - Where *Where - OrderBy OrderBy - Limit *Limit + node.Comments, node.Ignore, + node.Table, node.Partitions, node.Columns, node.Rows, node.OnDup) } // Format formats the node. @@ -602,33 +852,6 @@ func (node *Update) Format(buf *TrackedBuffer) { node.Exprs, node.Where, node.OrderBy, node.Limit) } -func (node *Update) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.TableExprs, - node.Exprs, - node.Where, - node.OrderBy, - node.Limit, - ) -} - -// Delete represents a DELETE statement. 
-// If you add fields here, consider adding them to calls to validateUnshardedRoute. -type Delete struct { - Comments Comments - Targets TableNames - TableExprs TableExprs - Partitions Partitions - Where *Where - OrderBy OrderBy - Limit *Limit -} - // Format formats the node. func (node *Delete) Format(buf *TrackedBuffer) { buf.Myprintf("delete %v", node.Comments) @@ -638,36 +861,6 @@ func (node *Delete) Format(buf *TrackedBuffer) { buf.Myprintf("from %v%v%v%v%v", node.TableExprs, node.Partitions, node.Where, node.OrderBy, node.Limit) } -func (node *Delete) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.Targets, - node.TableExprs, - node.Where, - node.OrderBy, - node.Limit, - ) -} - -// Set represents a SET statement. -type Set struct { - Comments Comments - Exprs SetExprs - Scope string -} - -// Set.Scope or Show.Scope -const ( - SessionStr = "session" - GlobalStr = "global" - VitessMetadataStr = "vitess_metadata" - ImplicitStr = "" -) - // Format formats the node. func (node *Set) Format(buf *TrackedBuffer) { if node.Scope == "" { @@ -677,26 +870,6 @@ func (node *Set) Format(buf *TrackedBuffer) { } } -func (node *Set) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Comments, - node.Exprs, - ) -} - -// DBDDL represents a CREATE, DROP, or ALTER database statement. -type DBDDL struct { - Action string - DBName string - IfExists bool - Collate string - Charset string -} - // Format formats the node. func (node *DBDDL) Format(buf *TrackedBuffer) { switch node.Action { @@ -711,61 +884,6 @@ func (node *DBDDL) Format(buf *TrackedBuffer) { } } -// walkSubtree walks the nodes of the subtree. -func (node *DBDDL) walkSubtree(visit Visit) error { - return nil -} - -// DDL represents a CREATE, ALTER, DROP, RENAME, TRUNCATE or ANALYZE statement. -type DDL struct { - Action string - - // FromTables is set if Action is RenameStr or DropStr. 
- FromTables TableNames - - // ToTables is set if Action is RenameStr. - ToTables TableNames - - // Table is set if Action is other than RenameStr or DropStr. - Table TableName - - // The following fields are set if a DDL was fully analyzed. - IfExists bool - TableSpec *TableSpec - OptLike *OptLike - PartitionSpec *PartitionSpec - - // VindexSpec is set for CreateVindexStr, DropVindexStr, AddColVindexStr, DropColVindexStr. - VindexSpec *VindexSpec - - // VindexCols is set for AddColVindexStr. - VindexCols []ColIdent - - // AutoIncSpec is set for AddAutoIncStr. - AutoIncSpec *AutoIncSpec -} - -// DDL strings. -const ( - CreateStr = "create" - AlterStr = "alter" - DropStr = "drop" - RenameStr = "rename" - TruncateStr = "truncate" - FlushStr = "flush" - CreateVindexStr = "create vindex" - DropVindexStr = "drop vindex" - AddVschemaTableStr = "add vschema table" - DropVschemaTableStr = "drop vschema table" - AddColVindexStr = "on table add vindex" - DropColVindexStr = "on table drop vindex" - AddSequenceStr = "add sequence" - AddAutoIncStr = "add auto_increment" - - // Vindex DDL param to specify the owner of a vindex - VindexOwnerStr = "owner" -) - // Format formats the node. func (node *DDL) Format(buf *TrackedBuffer) { switch node.Action { @@ -828,58 +946,11 @@ func (node *DDL) Format(buf *TrackedBuffer) { } } -func (node *DDL) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - for _, t := range node.AffectedTables() { - if err := Walk(visit, t); err != nil { - return err - } - } - return nil -} - -// AffectedTables returns the list table names affected by the DDL. -func (node *DDL) AffectedTables() TableNames { - if node.Action == RenameStr || node.Action == DropStr { - list := make(TableNames, 0, len(node.FromTables)+len(node.ToTables)) - list = append(list, node.FromTables...) - list = append(list, node.ToTables...) 
- return list - } - return TableNames{node.Table} -} - -// Partition strings -const ( - ReorganizeStr = "reorganize partition" -) - -// OptLike works for create table xxx like xxx -type OptLike struct { - LikeTable TableName -} - // Format formats the node. func (node *OptLike) Format(buf *TrackedBuffer) { buf.Myprintf("like %v", node.LikeTable) } -func (node *OptLike) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk(visit, node.LikeTable) -} - -// PartitionSpec describe partition actions (for alter and create) -type PartitionSpec struct { - Action string - Name ColIdent - Definitions []*PartitionDefinition -} - // Format formats the node. func (node *PartitionSpec) Format(buf *TrackedBuffer) { switch node.Action { @@ -896,28 +967,6 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { } } -func (node *PartitionSpec) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - if err := Walk(visit, node.Name); err != nil { - return err - } - for _, def := range node.Definitions { - if err := Walk(visit, def); err != nil { - return err - } - } - return nil -} - -// PartitionDefinition describes a very minimal partition definition -type PartitionDefinition struct { - Name ColIdent - Limit Expr - Maxvalue bool -} - // Format formats the node func (node *PartitionDefinition) Format(buf *TrackedBuffer) { if !node.Maxvalue { @@ -927,25 +976,6 @@ func (node *PartitionDefinition) Format(buf *TrackedBuffer) { } } -func (node *PartitionDefinition) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Name, - node.Limit, - ) -} - -// TableSpec describes the structure of a table from a CREATE TABLE statement -type TableSpec struct { - Columns []*ColumnDefinition - Indexes []*IndexDefinition - Constraints []*ConstraintDefinition - Options string -} - // Format formats the node. 
func (ts *TableSpec) Format(buf *TrackedBuffer) { buf.Myprintf("(\n") @@ -966,99 +996,11 @@ func (ts *TableSpec) Format(buf *TrackedBuffer) { buf.Myprintf("\n)%s", strings.Replace(ts.Options, ", ", ",\n ", -1)) } -// AddColumn appends the given column to the list in the spec -func (ts *TableSpec) AddColumn(cd *ColumnDefinition) { - ts.Columns = append(ts.Columns, cd) -} - -// AddIndex appends the given index to the list in the spec -func (ts *TableSpec) AddIndex(id *IndexDefinition) { - ts.Indexes = append(ts.Indexes, id) -} - -// AddConstraint appends the given index to the list in the spec -func (ts *TableSpec) AddConstraint(cd *ConstraintDefinition) { - ts.Constraints = append(ts.Constraints, cd) -} - -func (ts *TableSpec) walkSubtree(visit Visit) error { - if ts == nil { - return nil - } - - for _, n := range ts.Columns { - if err := Walk(visit, n); err != nil { - return err - } - } - - for _, n := range ts.Indexes { - if err := Walk(visit, n); err != nil { - return err - } - } - - for _, n := range ts.Constraints { - if err := Walk(visit, n); err != nil { - return err - } - } - - return nil -} - -// ColumnDefinition describes a column in a CREATE TABLE statement -type ColumnDefinition struct { - Name ColIdent - Type ColumnType -} - // Format formats the node. func (col *ColumnDefinition) Format(buf *TrackedBuffer) { buf.Myprintf("%v %v", col.Name, &col.Type) } -func (col *ColumnDefinition) walkSubtree(visit Visit) error { - if col == nil { - return nil - } - return Walk( - visit, - col.Name, - &col.Type, - ) -} - -// ColumnType represents a sql type in a CREATE TABLE statement -// All optional fields are nil if not specified -type ColumnType struct { - // The base type string - Type string - - // Generic field options. 
- NotNull BoolVal - Autoincrement BoolVal - Default Expr - OnUpdate Expr - Comment *SQLVal - - // Numeric field options - Length *SQLVal - Unsigned BoolVal - Zerofill BoolVal - Scale *SQLVal - - // Text field options - Charset string - Collate string - - // Enum values - EnumValues []string - - // Key specification - KeyOpt ColumnKeyOption -} - // Format returns a canonical string representation of the type and all relevant options func (ct *ColumnType) Format(buf *TrackedBuffer) { buf.Myprintf("%s", ct.Type) @@ -1123,141 +1065,6 @@ func (ct *ColumnType) Format(buf *TrackedBuffer) { } } -// DescribeType returns the abbreviated type information as required for -// describe table -func (ct *ColumnType) DescribeType() string { - buf := NewTrackedBuffer(nil) - buf.Myprintf("%s", ct.Type) - if ct.Length != nil && ct.Scale != nil { - buf.Myprintf("(%v,%v)", ct.Length, ct.Scale) - } else if ct.Length != nil { - buf.Myprintf("(%v)", ct.Length) - } - - opts := make([]string, 0, 16) - if ct.Unsigned { - opts = append(opts, keywordStrings[UNSIGNED]) - } - if ct.Zerofill { - opts = append(opts, keywordStrings[ZEROFILL]) - } - if len(opts) != 0 { - buf.Myprintf(" %s", strings.Join(opts, " ")) - } - return buf.String() -} - -// SQLType returns the sqltypes type code for the given column -func (ct *ColumnType) SQLType() querypb.Type { - switch ct.Type { - case keywordStrings[TINYINT]: - if ct.Unsigned { - return sqltypes.Uint8 - } - return sqltypes.Int8 - case keywordStrings[SMALLINT]: - if ct.Unsigned { - return sqltypes.Uint16 - } - return sqltypes.Int16 - case keywordStrings[MEDIUMINT]: - if ct.Unsigned { - return sqltypes.Uint24 - } - return sqltypes.Int24 - case keywordStrings[INT]: - fallthrough - case keywordStrings[INTEGER]: - if ct.Unsigned { - return sqltypes.Uint32 - } - return sqltypes.Int32 - case keywordStrings[BIGINT]: - if ct.Unsigned { - return sqltypes.Uint64 - } - return sqltypes.Int64 - case keywordStrings[BOOL], keywordStrings[BOOLEAN]: - return 
sqltypes.Uint8 - case keywordStrings[TEXT]: - return sqltypes.Text - case keywordStrings[TINYTEXT]: - return sqltypes.Text - case keywordStrings[MEDIUMTEXT]: - return sqltypes.Text - case keywordStrings[LONGTEXT]: - return sqltypes.Text - case keywordStrings[BLOB]: - return sqltypes.Blob - case keywordStrings[TINYBLOB]: - return sqltypes.Blob - case keywordStrings[MEDIUMBLOB]: - return sqltypes.Blob - case keywordStrings[LONGBLOB]: - return sqltypes.Blob - case keywordStrings[CHAR]: - return sqltypes.Char - case keywordStrings[VARCHAR]: - return sqltypes.VarChar - case keywordStrings[BINARY]: - return sqltypes.Binary - case keywordStrings[VARBINARY]: - return sqltypes.VarBinary - case keywordStrings[DATE]: - return sqltypes.Date - case keywordStrings[TIME]: - return sqltypes.Time - case keywordStrings[DATETIME]: - return sqltypes.Datetime - case keywordStrings[TIMESTAMP]: - return sqltypes.Timestamp - case keywordStrings[YEAR]: - return sqltypes.Year - case keywordStrings[FLOAT_TYPE]: - return sqltypes.Float32 - case keywordStrings[DOUBLE]: - return sqltypes.Float64 - case keywordStrings[DECIMAL]: - return sqltypes.Decimal - case keywordStrings[BIT]: - return sqltypes.Bit - case keywordStrings[ENUM]: - return sqltypes.Enum - case keywordStrings[SET]: - return sqltypes.Set - case keywordStrings[JSON]: - return sqltypes.TypeJSON - case keywordStrings[GEOMETRY]: - return sqltypes.Geometry - case keywordStrings[POINT]: - return sqltypes.Geometry - case keywordStrings[LINESTRING]: - return sqltypes.Geometry - case keywordStrings[POLYGON]: - return sqltypes.Geometry - case keywordStrings[GEOMETRYCOLLECTION]: - return sqltypes.Geometry - case keywordStrings[MULTIPOINT]: - return sqltypes.Geometry - case keywordStrings[MULTILINESTRING]: - return sqltypes.Geometry - case keywordStrings[MULTIPOLYGON]: - return sqltypes.Geometry - } - panic("unimplemented type " + ct.Type) -} - -func (ct *ColumnType) walkSubtree(visit Visit) error { - return nil -} - -// IndexDefinition 
describes an index in a CREATE TABLE statement -type IndexDefinition struct { - Info *IndexInfo - Columns []*IndexColumn - Options []*IndexOption -} - // Format formats the node. func (idx *IndexDefinition) Format(buf *TrackedBuffer) { buf.Myprintf("%v (", idx.Info) @@ -1283,29 +1090,6 @@ func (idx *IndexDefinition) Format(buf *TrackedBuffer) { } } -func (idx *IndexDefinition) walkSubtree(visit Visit) error { - if idx == nil { - return nil - } - - for _, n := range idx.Columns { - if err := Walk(visit, n.Column); err != nil { - return err - } - } - - return nil -} - -// IndexInfo describes the name and type of an index in a CREATE TABLE statement -type IndexInfo struct { - Type string - Name ColIdent - Primary bool - Spatial bool - Unique bool -} - // Format formats the node. func (ii *IndexInfo) Format(buf *TrackedBuffer) { if ii.Primary { @@ -1318,82 +1102,12 @@ func (ii *IndexInfo) Format(buf *TrackedBuffer) { } } -func (ii *IndexInfo) walkSubtree(visit Visit) error { - return Walk(visit, ii.Name) -} - -// IndexColumn describes a column in an index definition with optional length -type IndexColumn struct { - Column ColIdent - Length *SQLVal -} - -// LengthScaleOption is used for types that have an optional length -// and scale -type LengthScaleOption struct { - Length *SQLVal - Scale *SQLVal -} - -// IndexOption is used for trailing options for indexes: COMMENT, KEY_BLOCK_SIZE, USING -type IndexOption struct { - Name string - Value *SQLVal - Using string -} - -// ColumnKeyOption indicates whether or not the given column is defined as an -// index element and contains the type of the option -type ColumnKeyOption int - -const ( - colKeyNone ColumnKeyOption = iota - colKeyPrimary - colKeySpatialKey - colKeyUnique - colKeyUniqueKey - colKey -) - -// VindexSpec defines a vindex for a CREATE VINDEX or DROP VINDEX statement -type VindexSpec struct { - Name ColIdent - Type ColIdent - Params []VindexParam -} - -// AutoIncSpec defines and autoincrement value for a ADD 
AUTO_INCREMENT statement -type AutoIncSpec struct { - Column ColIdent - Sequence TableName -} - // Format formats the node. func (node *AutoIncSpec) Format(buf *TrackedBuffer) { buf.Myprintf("%v ", node.Column) buf.Myprintf("using %v", node.Sequence) } -func (node *AutoIncSpec) walkSubtree(visit Visit) error { - err := Walk(visit, node.Sequence, node.Column) - return err -} - -// ParseParams parses the vindex parameter list, pulling out the special-case -// "owner" parameter -func (node *VindexSpec) ParseParams() (string, map[string]string) { - var owner string - params := map[string]string{} - for _, p := range node.Params { - if p.Key.Lowered() == VindexOwnerStr { - owner = p.Val - } else { - params[p.Key.String()] = p.Val - } - } - return owner, params -} - // Format formats the node. The "CREATE VINDEX" preamble was formatted in // the containing DDL node Format, so this just prints the type, any // parameters, and optionally the owner @@ -1412,83 +1126,18 @@ func (node *VindexSpec) Format(buf *TrackedBuffer) { } } -func (node *VindexSpec) walkSubtree(visit Visit) error { - err := Walk(visit, - node.Name, - ) - - if err != nil { - return err - } - - for _, p := range node.Params { - err := Walk(visit, p) - - if err != nil { - return err - } - } - return nil -} - -// VindexParam defines a key/value parameter for a CREATE VINDEX statement -type VindexParam struct { - Key ColIdent - Val string -} - // Format formats the node. func (node VindexParam) Format(buf *TrackedBuffer) { buf.Myprintf("%s=%s", node.Key.String(), node.Val) } -func (node VindexParam) walkSubtree(visit Visit) error { - return Walk(visit, - node.Key, - ) -} - -// ConstraintDefinition describes a constraint in a CREATE TABLE statement -type ConstraintDefinition struct { - Name string - Details ConstraintInfo -} - -// ConstraintInfo details a constraint in a CREATE TABLE statement -type ConstraintInfo interface { - SQLNode - constraintInfo() -} - // Format formats the node. 
func (c *ConstraintDefinition) Format(buf *TrackedBuffer) { if c.Name != "" { buf.Myprintf("constraint %s ", c.Name) } - c.Details.Format(buf) -} - -func (c *ConstraintDefinition) walkSubtree(visit Visit) error { - return Walk(visit, c.Details) -} - -// ReferenceAction indicates the action takes by a referential constraint e.g. -// the `CASCADE` in a `FOREIGN KEY .. ON DELETE CASCADE` table definition. -type ReferenceAction int - -// These map to the SQL-defined reference actions. -// See https://dev.mysql.com/doc/refman/8.0/en/create-table-foreign-keys.html#foreign-keys-referential-actions -const ( - // DefaultAction indicates no action was explicitly specified. - DefaultAction ReferenceAction = iota - Restrict - Cascade - NoAction - SetNull - SetDefault -) - -func (a ReferenceAction) walkSubtree(visit Visit) error { return nil } + c.Details.Format(buf) +} // Format formats the node. func (a ReferenceAction) Format(buf *TrackedBuffer) { @@ -1506,17 +1155,6 @@ func (a ReferenceAction) Format(buf *TrackedBuffer) { } } -// ForeignKeyDefinition describes a foreign key in a CREATE TABLE statement -type ForeignKeyDefinition struct { - Source Columns - ReferencedTable TableName - ReferencedColumns Columns - OnDelete ReferenceAction - OnUpdate ReferenceAction -} - -var _ ConstraintInfo = &ForeignKeyDefinition{} - // Format formats the node. func (f *ForeignKeyDefinition) Format(buf *TrackedBuffer) { buf.Myprintf("foreign key %v references %v %v", f.Source, f.ReferencedTable, f.ReferencedColumns) @@ -1528,28 +1166,6 @@ func (f *ForeignKeyDefinition) Format(buf *TrackedBuffer) { } } -func (f *ForeignKeyDefinition) constraintInfo() {} - -func (f *ForeignKeyDefinition) walkSubtree(visit Visit) error { - if err := Walk(visit, f.Source); err != nil { - return err - } - if err := Walk(visit, f.ReferencedTable); err != nil { - return err - } - return Walk(visit, f.ReferencedColumns) -} - -// Show represents a show statement. 
-type Show struct { - Type string - OnTable TableName - Table TableName - ShowTablesOpt *ShowTablesOpt - Scope string - ShowCollationFilterOpt *Expr -} - // Format formats the node. func (node *Show) Format(buf *TrackedBuffer) { if (node.Type == "tables" || node.Type == "columns" || node.Type == "fields") && node.ShowTablesOpt != nil { @@ -1580,34 +1196,6 @@ func (node *Show) Format(buf *TrackedBuffer) { } } -// HasOnTable returns true if the show statement has an "on" clause -func (node *Show) HasOnTable() bool { - return node.OnTable.Name.v != "" -} - -// HasTable returns true if the show statement has a parsed table name. -// Not all show statements parse table names. -func (node *Show) HasTable() bool { - return node.Table.Name.v != "" -} - -func (node *Show) walkSubtree(visit Visit) error { - return nil -} - -// ShowTablesOpt is show tables option -type ShowTablesOpt struct { - Full string - DbName string - Filter *ShowFilter -} - -// ShowFilter is show tables filter -type ShowFilter struct { - Like string - Filter Expr -} - // Format formats the node. func (node *ShowFilter) Format(buf *TrackedBuffer) { if node == nil { @@ -1620,15 +1208,6 @@ func (node *ShowFilter) Format(buf *TrackedBuffer) { } } -func (node *ShowFilter) walkSubtree(visit Visit) error { - return nil -} - -// Use represents a use statement. -type Use struct { - DBName TableIdent -} - // Format formats the node. func (node *Use) Format(buf *TrackedBuffer) { if node.DBName.v != "" { @@ -1638,78 +1217,31 @@ func (node *Use) Format(buf *TrackedBuffer) { } } -func (node *Use) walkSubtree(visit Visit) error { - return Walk(visit, node.DBName) -} - -// Begin represents a Begin statement. -type Begin struct{} - -// Format formats the node. -func (node *Begin) Format(buf *TrackedBuffer) { - buf.WriteString("begin") -} - -func (node *Begin) walkSubtree(visit Visit) error { - return nil -} - -// Commit represents a Commit statement. -type Commit struct{} - // Format formats the node. 
func (node *Commit) Format(buf *TrackedBuffer) { buf.WriteString("commit") } -func (node *Commit) walkSubtree(visit Visit) error { - return nil +// Format formats the node. +func (node *Begin) Format(buf *TrackedBuffer) { + buf.WriteString("begin") } -// Rollback represents a Rollback statement. -type Rollback struct{} - // Format formats the node. func (node *Rollback) Format(buf *TrackedBuffer) { buf.WriteString("rollback") } -func (node *Rollback) walkSubtree(visit Visit) error { - return nil -} - -// OtherRead represents a DESCRIBE, or EXPLAIN statement. -// It should be used only as an indicator. It does not contain -// the full AST for the statement. -type OtherRead struct{} - // Format formats the node. func (node *OtherRead) Format(buf *TrackedBuffer) { buf.WriteString("otherread") } -func (node *OtherRead) walkSubtree(visit Visit) error { - return nil -} - -// OtherAdmin represents a misc statement that relies on ADMIN privileges, -// such as REPAIR, OPTIMIZE, or TRUNCATE statement. -// It should be used only as an indicator. It does not contain -// the full AST for the statement. -type OtherAdmin struct{} - // Format formats the node. func (node *OtherAdmin) Format(buf *TrackedBuffer) { buf.WriteString("otheradmin") } -func (node *OtherAdmin) walkSubtree(visit Visit) error { - return nil -} - -// Comments represents a list of comments. -type Comments [][]byte - // Format formats the node. func (node Comments) Format(buf *TrackedBuffer) { for _, c := range node { @@ -1717,13 +1249,6 @@ func (node Comments) Format(buf *TrackedBuffer) { } } -func (node Comments) walkSubtree(visit Visit) error { - return nil -} - -// SelectExprs represents SELECT expressions. -type SelectExprs []SelectExpr - // Format formats the node. 
func (node SelectExprs) Format(buf *TrackedBuffer) { var prefix string @@ -1733,30 +1258,6 @@ func (node SelectExprs) Format(buf *TrackedBuffer) { } } -func (node SelectExprs) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// SelectExpr represents a SELECT expression. -type SelectExpr interface { - iSelectExpr() - SQLNode -} - -func (*StarExpr) iSelectExpr() {} -func (*AliasedExpr) iSelectExpr() {} -func (Nextval) iSelectExpr() {} - -// StarExpr defines a '*' or 'table.*' expression. -type StarExpr struct { - TableName TableName -} - // Format formats the node. func (node *StarExpr) Format(buf *TrackedBuffer) { if !node.TableName.IsEmpty() { @@ -1765,22 +1266,6 @@ func (node *StarExpr) Format(buf *TrackedBuffer) { buf.Myprintf("*") } -func (node *StarExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.TableName, - ) -} - -// AliasedExpr defines an aliased SELECT expression. -type AliasedExpr struct { - Expr Expr - As ColIdent -} - // Format formats the node. func (node *AliasedExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v", node.Expr) @@ -1789,34 +1274,11 @@ func (node *AliasedExpr) Format(buf *TrackedBuffer) { } } -func (node *AliasedExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - node.As, - ) -} - -// Nextval defines the NEXT VALUE expression. -type Nextval struct { - Expr Expr -} - // Format formats the node. func (node Nextval) Format(buf *TrackedBuffer) { buf.Myprintf("next %v values", node.Expr) } -func (node Nextval) walkSubtree(visit Visit) error { - return Walk(visit, node.Expr) -} - -// Columns represents an insert column list. -type Columns []ColIdent - // Format formats the node. 
func (node Columns) Format(buf *TrackedBuffer) { if node == nil { @@ -1830,29 +1292,6 @@ func (node Columns) Format(buf *TrackedBuffer) { buf.WriteString(")") } -func (node Columns) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// FindColumn finds a column in the column list, returning -// the index if it exists or -1 otherwise -func (node Columns) FindColumn(col ColIdent) int { - for i, colName := range node { - if colName.Equal(col) { - return i - } - } - return -1 -} - -// Partitions is a type alias for Columns so we can handle printing efficiently -type Partitions Columns - // Format formats the node func (node Partitions) Format(buf *TrackedBuffer) { if node == nil { @@ -1866,18 +1305,6 @@ func (node Partitions) Format(buf *TrackedBuffer) { buf.WriteString(")") } -func (node Partitions) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// TableExprs represents a list of table expressions. -type TableExprs []TableExpr - // Format formats the node. func (node TableExprs) Format(buf *TrackedBuffer) { var prefix string @@ -1887,35 +1314,6 @@ func (node TableExprs) Format(buf *TrackedBuffer) { } } -func (node TableExprs) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// TableExpr represents a table expression. -type TableExpr interface { - iTableExpr() - SQLNode -} - -func (*AliasedTableExpr) iTableExpr() {} -func (*ParenTableExpr) iTableExpr() {} -func (*JoinTableExpr) iTableExpr() {} - -// AliasedTableExpr represents a table expression -// coupled with an optional alias or index hint. -// If As is empty, no alias was used. -type AliasedTableExpr struct { - Expr SimpleTableExpr - Partitions Partitions - As TableIdent - Hints *IndexHints -} - // Format formats the node. 
func (node *AliasedTableExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v%v", node.Expr, node.Partitions) @@ -1928,37 +1326,6 @@ func (node *AliasedTableExpr) Format(buf *TrackedBuffer) { } } -func (node *AliasedTableExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - node.As, - node.Hints, - ) -} - -// RemoveHints returns a new AliasedTableExpr with the hints removed. -func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr { - noHints := *node - noHints.Hints = nil - return &noHints -} - -// SimpleTableExpr represents a simple table expression. -type SimpleTableExpr interface { - iSimpleTableExpr() - SQLNode -} - -func (TableName) iSimpleTableExpr() {} -func (*Subquery) iSimpleTableExpr() {} - -// TableNames is a list of TableName. -type TableNames []TableName - // Format formats the node. func (node TableNames) Format(buf *TrackedBuffer) { var prefix string @@ -1968,24 +1335,6 @@ func (node TableNames) Format(buf *TrackedBuffer) { } } -func (node TableNames) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// TableName represents a table name. -// Qualifier, if specified, represents a database or keyspace. -// TableName is a value struct whose fields are case sensitive. -// This means two TableName vars can be compared for equality -// and a TableName can also be used as key in a map. -type TableName struct { - Name, Qualifier TableIdent -} - // Format formats the node. func (node TableName) Format(buf *TrackedBuffer) { if node.IsEmpty() { @@ -1997,57 +1346,11 @@ func (node TableName) Format(buf *TrackedBuffer) { buf.Myprintf("%v", node.Name) } -func (node TableName) walkSubtree(visit Visit) error { - return Walk( - visit, - node.Name, - node.Qualifier, - ) -} - -// IsEmpty returns true if TableName is nil or empty. -func (node TableName) IsEmpty() bool { - // If Name is empty, Qualifier is also empty. 
- return node.Name.IsEmpty() -} - -// ToViewName returns a TableName acceptable for use as a VIEW. VIEW names are -// always lowercase, so ToViewName lowercasese the name. Databases are case-sensitive -// so Qualifier is left untouched. -func (node TableName) ToViewName() TableName { - return TableName{ - Qualifier: node.Qualifier, - Name: NewTableIdent(strings.ToLower(node.Name.v)), - } -} - -// ParenTableExpr represents a parenthesized list of TableExpr. -type ParenTableExpr struct { - Exprs TableExprs -} - // Format formats the node. func (node *ParenTableExpr) Format(buf *TrackedBuffer) { buf.Myprintf("(%v)", node.Exprs) } -func (node *ParenTableExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Exprs, - ) -} - -// JoinCondition represents the join conditions (either a ON or USING clause) -// of a JoinTableExpr. -type JoinCondition struct { - On Expr - Using Columns -} - // Format formats the node. func (node JoinCondition) Format(buf *TrackedBuffer) { if node.On != nil { @@ -2058,63 +1361,11 @@ func (node JoinCondition) Format(buf *TrackedBuffer) { } } -func (node JoinCondition) walkSubtree(visit Visit) error { - return Walk( - visit, - node.On, - node.Using, - ) -} - -// JoinTableExpr represents a TableExpr that's a JOIN operation. -type JoinTableExpr struct { - LeftExpr TableExpr - Join string - RightExpr TableExpr - Condition JoinCondition -} - -// JoinTableExpr.Join -const ( - JoinStr = "join" - StraightJoinStr = "straight_join" - LeftJoinStr = "left join" - RightJoinStr = "right join" - NaturalJoinStr = "natural join" - NaturalLeftJoinStr = "natural left join" - NaturalRightJoinStr = "natural right join" -) - // Format formats the node. 
func (node *JoinTableExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s %v%v", node.LeftExpr, node.Join, node.RightExpr, node.Condition) } -func (node *JoinTableExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.LeftExpr, - node.RightExpr, - node.Condition, - ) -} - -// IndexHints represents a list of index hints. -type IndexHints struct { - Type string - Indexes []ColIdent -} - -// Index hints. -const ( - UseStr = "use " - IgnoreStr = "ignore " - ForceStr = "force " -) - // Format formats the node. func (node *IndexHints) Format(buf *TrackedBuffer) { buf.Myprintf(" %sindex ", node.Type) @@ -2123,134 +1374,17 @@ func (node *IndexHints) Format(buf *TrackedBuffer) { buf.Myprintf("%s%v", prefix, n) prefix = ", " } - buf.Myprintf(")") -} - -func (node *IndexHints) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - for _, n := range node.Indexes { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// Where represents a WHERE or HAVING clause. -type Where struct { - Type string - Expr Expr -} - -// Where.Type -const ( - WhereStr = "where" - HavingStr = "having" -) - -// NewWhere creates a WHERE or HAVING clause out -// of a Expr. If the expression is nil, it returns nil. -func NewWhere(typ string, expr Expr) *Where { - if expr == nil { - return nil - } - return &Where{Type: typ, Expr: expr} -} - -// Format formats the node. -func (node *Where) Format(buf *TrackedBuffer) { - if node == nil || node.Expr == nil { - return - } - buf.Myprintf(" %s %v", node.Type, node.Expr) -} - -func (node *Where) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -// Expr represents an expression. -type Expr interface { - iExpr() - // replace replaces any subexpression that matches - // from with to. The implementation can use the - // replaceExprs convenience function. 
- replace(from, to Expr) bool - SQLNode -} - -func (*AndExpr) iExpr() {} -func (*OrExpr) iExpr() {} -func (*NotExpr) iExpr() {} -func (*ParenExpr) iExpr() {} -func (*ComparisonExpr) iExpr() {} -func (*RangeCond) iExpr() {} -func (*IsExpr) iExpr() {} -func (*ExistsExpr) iExpr() {} -func (*SQLVal) iExpr() {} -func (*NullVal) iExpr() {} -func (BoolVal) iExpr() {} -func (*ColName) iExpr() {} -func (ValTuple) iExpr() {} -func (*Subquery) iExpr() {} -func (ListArg) iExpr() {} -func (*BinaryExpr) iExpr() {} -func (*UnaryExpr) iExpr() {} -func (*IntervalExpr) iExpr() {} -func (*CollateExpr) iExpr() {} -func (*FuncExpr) iExpr() {} -func (*TimestampFuncExpr) iExpr() {} -func (*CurTimeFuncExpr) iExpr() {} -func (*CaseExpr) iExpr() {} -func (*ValuesFuncExpr) iExpr() {} -func (*ConvertExpr) iExpr() {} -func (*SubstrExpr) iExpr() {} -func (*ConvertUsingExpr) iExpr() {} -func (*MatchExpr) iExpr() {} -func (*GroupConcatExpr) iExpr() {} -func (*Default) iExpr() {} - -// ReplaceExpr finds the from expression from root -// and replaces it with to. If from matches root, -// then to is returned. -func ReplaceExpr(root, from, to Expr) Expr { - if root == from { - return to - } - root.replace(from, to) - return root + buf.Myprintf(")") } -// replaceExprs is a convenience function used by implementors -// of the replace method. -func replaceExprs(from, to Expr, exprs ...*Expr) bool { - for _, expr := range exprs { - if *expr == nil { - continue - } - if *expr == from { - *expr = to - return true - } - if (*expr).replace(from, to) { - return true - } +// Format formats the node. +func (node *Where) Format(buf *TrackedBuffer) { + if node == nil || node.Expr == nil { + return } - return false + buf.Myprintf(" %s %v", node.Type, node.Expr) } -// Exprs represents a list of value expressions. -// It's not a valid expression because it's not parenthesized. -type Exprs []Expr - // Format formats the node. 
func (node Exprs) Format(buf *TrackedBuffer) { var prefix string @@ -2260,139 +1394,26 @@ func (node Exprs) Format(buf *TrackedBuffer) { } } -func (node Exprs) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// AndExpr represents an AND expression. -type AndExpr struct { - Left, Right Expr -} - // Format formats the node. func (node *AndExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v and %v", node.Left, node.Right) } -func (node *AndExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.Right, - ) -} - -func (node *AndExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Left, &node.Right) -} - -// OrExpr represents an OR expression. -type OrExpr struct { - Left, Right Expr -} - // Format formats the node. func (node *OrExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v or %v", node.Left, node.Right) } -func (node *OrExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.Right, - ) -} - -func (node *OrExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Left, &node.Right) -} - -// NotExpr represents a NOT expression. -type NotExpr struct { - Expr Expr -} - // Format formats the node. func (node *NotExpr) Format(buf *TrackedBuffer) { buf.Myprintf("not %v", node.Expr) } -func (node *NotExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *NotExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// ParenExpr represents a parenthesized boolean expression. -type ParenExpr struct { - Expr Expr -} - // Format formats the node. 
func (node *ParenExpr) Format(buf *TrackedBuffer) { buf.Myprintf("(%v)", node.Expr) } -func (node *ParenExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *ParenExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// ComparisonExpr represents a two-value comparison expression. -type ComparisonExpr struct { - Operator string - Left, Right Expr - Escape Expr -} - -// ComparisonExpr.Operator -const ( - EqualStr = "=" - LessThanStr = "<" - GreaterThanStr = ">" - LessEqualStr = "<=" - GreaterEqualStr = ">=" - NotEqualStr = "!=" - NullSafeEqualStr = "<=>" - InStr = "in" - NotInStr = "not in" - LikeStr = "like" - NotLikeStr = "not like" - RegexpStr = "regexp" - NotRegexpStr = "not regexp" - JSONExtractOp = "->" - JSONUnquoteExtractOp = "->>" -) - // Format formats the node. func (node *ComparisonExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s %v", node.Left, node.Operator, node.Right) @@ -2401,218 +1422,21 @@ func (node *ComparisonExpr) Format(buf *TrackedBuffer) { } } -func (node *ComparisonExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.Right, - node.Escape, - ) -} - -func (node *ComparisonExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Left, &node.Right, &node.Escape) -} - -// IsImpossible returns true if the comparison in the expression can never evaluate to true. -// Note that this is not currently exhaustive to ALL impossible comparisons. 
-func (node *ComparisonExpr) IsImpossible() bool { - var left, right *SQLVal - var ok bool - if left, ok = node.Left.(*SQLVal); !ok { - return false - } - if right, ok = node.Right.(*SQLVal); !ok { - return false - } - if node.Operator == NotEqualStr && left.Type == right.Type { - if len(left.Val) != len(right.Val) { - return false - } - - for i := range left.Val { - if left.Val[i] != right.Val[i] { - return false - } - } - return true - } - return false -} - -// RangeCond represents a BETWEEN or a NOT BETWEEN expression. -type RangeCond struct { - Operator string - Left Expr - From, To Expr -} - -// RangeCond.Operator -const ( - BetweenStr = "between" - NotBetweenStr = "not between" -) - // Format formats the node. func (node *RangeCond) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s %v and %v", node.Left, node.Operator, node.From, node.To) } -func (node *RangeCond) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.From, - node.To, - ) -} - -func (node *RangeCond) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Left, &node.From, &node.To) -} - -// IsExpr represents an IS ... or an IS NOT ... expression. -type IsExpr struct { - Operator string - Expr Expr -} - -// IsExpr.Operator -const ( - IsNullStr = "is null" - IsNotNullStr = "is not null" - IsTrueStr = "is true" - IsNotTrueStr = "is not true" - IsFalseStr = "is false" - IsNotFalseStr = "is not false" -) - // Format formats the node. func (node *IsExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s", node.Expr, node.Operator) } -func (node *IsExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *IsExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// ExistsExpr represents an EXISTS expression. -type ExistsExpr struct { - Subquery *Subquery -} - // Format formats the node. 
func (node *ExistsExpr) Format(buf *TrackedBuffer) { buf.Myprintf("exists %v", node.Subquery) } -func (node *ExistsExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Subquery, - ) -} - -func (node *ExistsExpr) replace(from, to Expr) bool { - return false -} - -// ExprFromValue converts the given Value into an Expr or returns an error. -func ExprFromValue(value sqltypes.Value) (Expr, error) { - // The type checks here follow the rules defined in sqltypes/types.go. - switch { - case value.Type() == sqltypes.Null: - return &NullVal{}, nil - case value.IsIntegral(): - return NewIntVal(value.ToBytes()), nil - case value.IsFloat() || value.Type() == sqltypes.Decimal: - return NewFloatVal(value.ToBytes()), nil - case value.IsQuoted(): - return NewStrVal(value.ToBytes()), nil - default: - // We cannot support sqltypes.Expression, or any other invalid type. - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot convert value %v to AST", value) - } -} - -// ValType specifies the type for SQLVal. -type ValType int - -// These are the possible Valtype values. -// HexNum represents a 0x... value. It cannot -// be treated as a simple value because it can -// be interpreted differently depending on the -// context. -const ( - StrVal = ValType(iota) - IntVal - FloatVal - HexNum - HexVal - ValArg - BitVal -) - -// SQLVal represents a single value. -type SQLVal struct { - Type ValType - Val []byte -} - -// NewStrVal builds a new StrVal. -func NewStrVal(in []byte) *SQLVal { - return &SQLVal{Type: StrVal, Val: in} -} - -// NewIntVal builds a new IntVal. -func NewIntVal(in []byte) *SQLVal { - return &SQLVal{Type: IntVal, Val: in} -} - -// NewFloatVal builds a new FloatVal. -func NewFloatVal(in []byte) *SQLVal { - return &SQLVal{Type: FloatVal, Val: in} -} - -// NewHexNum builds a new HexNum. -func NewHexNum(in []byte) *SQLVal { - return &SQLVal{Type: HexNum, Val: in} -} - -// NewHexVal builds a new HexVal. 
-func NewHexVal(in []byte) *SQLVal { - return &SQLVal{Type: HexVal, Val: in} -} - -// NewBitVal builds a new BitVal containing a bit literal. -func NewBitVal(in []byte) *SQLVal { - return &SQLVal{Type: BitVal, Val: in} -} - -// NewValArg builds a new ValArg. -func NewValArg(in []byte) *SQLVal { - return &SQLVal{Type: ValArg, Val: in} -} - // Format formats the node. func (node *SQLVal) Format(buf *TrackedBuffer) { switch node.Type { @@ -2631,43 +1455,11 @@ func (node *SQLVal) Format(buf *TrackedBuffer) { } } -func (node *SQLVal) walkSubtree(visit Visit) error { - return nil -} - -func (node *SQLVal) replace(from, to Expr) bool { - return false -} - -// HexDecode decodes the hexval into bytes. -func (node *SQLVal) HexDecode() ([]byte, error) { - dst := make([]byte, hex.DecodedLen(len([]byte(node.Val)))) - _, err := hex.Decode(dst, []byte(node.Val)) - if err != nil { - return nil, err - } - return dst, err -} - -// NullVal represents a NULL value. -type NullVal struct{} - // Format formats the node. func (node *NullVal) Format(buf *TrackedBuffer) { buf.Myprintf("null") } -func (node *NullVal) walkSubtree(visit Visit) error { - return nil -} - -func (node *NullVal) replace(from, to Expr) bool { - return false -} - -// BoolVal is true or false. -type BoolVal bool - // Format formats the node. func (node BoolVal) Format(buf *TrackedBuffer) { if node { @@ -2677,25 +1469,6 @@ func (node BoolVal) Format(buf *TrackedBuffer) { } } -func (node BoolVal) walkSubtree(visit Visit) error { - return nil -} - -func (node BoolVal) replace(from, to Expr) bool { - return false -} - -// ColName represents a column name. -type ColName struct { - // Metadata is not populated by the parser. - // It's a placeholder for analyzers to store - // additional data, typically info about which - // table or column this node references. - Metadata interface{} - Name ColIdent - Qualifier TableName -} - // Format formats the node. 
func (node *ColName) Format(buf *TrackedBuffer) { if !node.Qualifier.IsEmpty() { @@ -2704,160 +1477,26 @@ func (node *ColName) Format(buf *TrackedBuffer) { buf.Myprintf("%v", node.Name) } -func (node *ColName) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Name, - node.Qualifier, - ) -} - -func (node *ColName) replace(from, to Expr) bool { - return false -} - -// Equal returns true if the column names match. -func (node *ColName) Equal(c *ColName) bool { - // Failsafe: ColName should not be empty. - if node == nil || c == nil { - return false - } - return node.Name.Equal(c.Name) && node.Qualifier == c.Qualifier -} - -// ColTuple represents a list of column values. -// It can be ValTuple, Subquery, ListArg. -type ColTuple interface { - iColTuple() - Expr -} - -func (ValTuple) iColTuple() {} -func (*Subquery) iColTuple() {} -func (ListArg) iColTuple() {} - -// ValTuple represents a tuple of actual values. -type ValTuple Exprs - // Format formats the node. func (node ValTuple) Format(buf *TrackedBuffer) { buf.Myprintf("(%v)", Exprs(node)) } -func (node ValTuple) walkSubtree(visit Visit) error { - return Walk(visit, Exprs(node)) -} - -func (node ValTuple) replace(from, to Expr) bool { - for i := range node { - if replaceExprs(from, to, &node[i]) { - return true - } - } - return false -} - -// Subquery represents a subquery. -type Subquery struct { - Select SelectStatement -} - // Format formats the node. func (node *Subquery) Format(buf *TrackedBuffer) { buf.Myprintf("(%v)", node.Select) } -func (node *Subquery) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Select, - ) -} - -func (node *Subquery) replace(from, to Expr) bool { - return false -} - -// ListArg represents a named list argument. -type ListArg []byte - // Format formats the node. 
func (node ListArg) Format(buf *TrackedBuffer) { buf.WriteArg(string(node)) } -func (node ListArg) walkSubtree(visit Visit) error { - return nil -} - -func (node ListArg) replace(from, to Expr) bool { - return false -} - -// BinaryExpr represents a binary value expression. -type BinaryExpr struct { - Operator string - Left, Right Expr -} - -// BinaryExpr.Operator -const ( - BitAndStr = "&" - BitOrStr = "|" - BitXorStr = "^" - PlusStr = "+" - MinusStr = "-" - MultStr = "*" - DivStr = "/" - IntDivStr = "div" - ModStr = "%" - ShiftLeftStr = "<<" - ShiftRightStr = ">>" -) - // Format formats the node. func (node *BinaryExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s %v", node.Left, node.Operator, node.Right) } -func (node *BinaryExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Left, - node.Right, - ) -} - -func (node *BinaryExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Left, &node.Right) -} - -// UnaryExpr represents a unary value expression. -type UnaryExpr struct { - Operator string - Expr Expr -} - -// UnaryExpr.Operator -const ( - UPlusStr = "+" - UMinusStr = "-" - TildaStr = "~" - BangStr = "!" - BinaryStr = "binary " - UBinaryStr = "_binary " - Utf8mb4Str = "_utf8mb4 " -) - // Format formats the node. func (node *UnaryExpr) Format(buf *TrackedBuffer) { if _, unary := node.Expr.(*UnaryExpr); unary { @@ -2867,248 +1506,44 @@ func (node *UnaryExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%s%v", node.Operator, node.Expr) } -func (node *UnaryExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *UnaryExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// IntervalExpr represents a date-time INTERVAL expression. -type IntervalExpr struct { - Expr Expr - Unit string -} - // Format formats the node. 
func (node *IntervalExpr) Format(buf *TrackedBuffer) { buf.Myprintf("interval %v %s", node.Expr, node.Unit) } -func (node *IntervalExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *IntervalExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// TimestampFuncExpr represents the function and arguments for TIMESTAMP{ADD,DIFF} functions. -type TimestampFuncExpr struct { - Name string - Expr1 Expr - Expr2 Expr - Unit string -} - // Format formats the node. func (node *TimestampFuncExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%s(%s, %v, %v)", node.Name, node.Unit, node.Expr1, node.Expr2) } -func (node *TimestampFuncExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr1, - node.Expr2, - ) -} - -func (node *TimestampFuncExpr) replace(from, to Expr) bool { - if replaceExprs(from, to, &node.Expr1) { - return true - } - if replaceExprs(from, to, &node.Expr2) { - return true - } - return false -} - -// CurTimeFuncExpr represents the function and arguments for CURRENT DATE/TIME functions -// supported functions are documented in the grammar -type CurTimeFuncExpr struct { - Name ColIdent - Fsp Expr // fractional seconds precision, integer from 0 to 6 -} - // Format formats the node. func (node *CurTimeFuncExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%s(%v)", node.Name.String(), node.Fsp) } -func (node *CurTimeFuncExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Fsp, - ) -} - -func (node *CurTimeFuncExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Fsp) -} - -// CollateExpr represents dynamic collate operator. -type CollateExpr struct { - Expr Expr - Charset string -} - // Format formats the node. 
func (node *CollateExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v collate %s", node.Expr, node.Charset) } -func (node *CollateExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *CollateExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// FuncExpr represents a function call. -type FuncExpr struct { - Qualifier TableIdent - Name ColIdent - Distinct bool - Exprs SelectExprs -} - // Format formats the node. func (node *FuncExpr) Format(buf *TrackedBuffer) { var distinct string if node.Distinct { distinct = "distinct " } - if !node.Qualifier.IsEmpty() { - buf.Myprintf("%v.", node.Qualifier) - } - // Function names should not be back-quoted even - // if they match a reserved word. So, print the - // name as is. - buf.Myprintf("%s(%s%v)", node.Name.String(), distinct, node.Exprs) -} - -func (node *FuncExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Qualifier, - node.Name, - node.Exprs, - ) -} - -func (node *FuncExpr) replace(from, to Expr) bool { - for _, sel := range node.Exprs { - aliased, ok := sel.(*AliasedExpr) - if !ok { - continue - } - if replaceExprs(from, to, &aliased.Expr) { - return true - } - } - return false -} - -// Aggregates is a map of all aggregate functions. -var Aggregates = map[string]bool{ - "avg": true, - "bit_and": true, - "bit_or": true, - "bit_xor": true, - "count": true, - "group_concat": true, - "max": true, - "min": true, - "std": true, - "stddev_pop": true, - "stddev_samp": true, - "stddev": true, - "sum": true, - "var_pop": true, - "var_samp": true, - "variance": true, -} - -// IsAggregate returns true if the function is an aggregate. 
-func (node *FuncExpr) IsAggregate() bool { - return Aggregates[node.Name.Lowered()] -} - -// GroupConcatExpr represents a call to GROUP_CONCAT -type GroupConcatExpr struct { - Distinct string - Exprs SelectExprs - OrderBy OrderBy - Separator string -} - -// Format formats the node -func (node *GroupConcatExpr) Format(buf *TrackedBuffer) { - buf.Myprintf("group_concat(%s%v%v%s)", node.Distinct, node.Exprs, node.OrderBy, node.Separator) -} - -func (node *GroupConcatExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Exprs, - node.OrderBy, - ) -} - -func (node *GroupConcatExpr) replace(from, to Expr) bool { - for _, sel := range node.Exprs { - aliased, ok := sel.(*AliasedExpr) - if !ok { - continue - } - if replaceExprs(from, to, &aliased.Expr) { - return true - } - } - for _, order := range node.OrderBy { - if replaceExprs(from, to, &order.Expr) { - return true - } + if !node.Qualifier.IsEmpty() { + buf.Myprintf("%v.", node.Qualifier) } - return false + // Function names should not be back-quoted even + // if they match a reserved word. So, print the + // name as is. + buf.Myprintf("%s(%s%v)", node.Name.String(), distinct, node.Exprs) } -// ValuesFuncExpr represents a function call. -type ValuesFuncExpr struct { - Name *ColName +// Format formats the node +func (node *GroupConcatExpr) Format(buf *TrackedBuffer) { + buf.Myprintf("group_concat(%s%v%v%s)", node.Distinct, node.Exprs, node.OrderBy, node.Separator) } // Format formats the node. 
@@ -3116,32 +1551,6 @@ func (node *ValuesFuncExpr) Format(buf *TrackedBuffer) { buf.Myprintf("values(%v)", node.Name) } -func (node *ValuesFuncExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Name, - ) -} - -func (node *ValuesFuncExpr) replace(from, to Expr) bool { - return false -} - -// SubstrExpr represents a call to SubstrExpr(column, value_expression) or SubstrExpr(column, value_expression,value_expression) -// also supported syntax SubstrExpr(column from value_expression for value_expression). -// Additionally to column names, SubstrExpr is also supported for string values, e.g.: -// SubstrExpr('static string value', value_expression, value_expression) -// In this case StrVal will be set instead of Name. -type SubstrExpr struct { - Name *ColName - StrVal *SQLVal - From Expr - To Expr -} - // Format formats the node. func (node *SubstrExpr) Format(buf *TrackedBuffer) { var val interface{} @@ -3158,89 +1567,16 @@ func (node *SubstrExpr) Format(buf *TrackedBuffer) { } } -func (node *SubstrExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.From, &node.To) -} - -func (node *SubstrExpr) walkSubtree(visit Visit) error { - if node == nil || node.Name == nil { - return nil - } - return Walk( - visit, - node.Name, - node.From, - node.To, - ) -} - -// ConvertExpr represents a call to CONVERT(expr, type) -// or it's equivalent CAST(expr AS type). Both are rewritten to the former. -type ConvertExpr struct { - Expr Expr - Type *ConvertType -} - // Format formats the node. 
func (node *ConvertExpr) Format(buf *TrackedBuffer) { buf.Myprintf("convert(%v, %v)", node.Expr, node.Type) } -func (node *ConvertExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - node.Type, - ) -} - -func (node *ConvertExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// ConvertUsingExpr represents a call to CONVERT(expr USING charset). -type ConvertUsingExpr struct { - Expr Expr - Type string -} - // Format formats the node. func (node *ConvertUsingExpr) Format(buf *TrackedBuffer) { buf.Myprintf("convert(%v using %s)", node.Expr, node.Type) } -func (node *ConvertUsingExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -func (node *ConvertUsingExpr) replace(from, to Expr) bool { - return replaceExprs(from, to, &node.Expr) -} - -// ConvertType represents the type in call to CONVERT(expr, type) -type ConvertType struct { - Type string - Length *SQLVal - Scale *SQLVal - Operator string - Charset string -} - -// this string is "character set" and this comment is required -const ( - CharacterSetStr = " character set" - CharsetStr = "charset" -) - // Format formats the node. 
func (node *ConvertType) Format(buf *TrackedBuffer) { buf.Myprintf("%s", node.Type) @@ -3256,61 +1592,11 @@ func (node *ConvertType) Format(buf *TrackedBuffer) { } } -func (node *ConvertType) walkSubtree(visit Visit) error { - return nil -} - -// MatchExpr represents a call to the MATCH function -type MatchExpr struct { - Columns SelectExprs - Expr Expr - Option string -} - -// MatchExpr.Option -const ( - BooleanModeStr = " in boolean mode" - NaturalLanguageModeStr = " in natural language mode" - NaturalLanguageModeWithQueryExpansionStr = " in natural language mode with query expansion" - QueryExpansionStr = " with query expansion" -) - // Format formats the node func (node *MatchExpr) Format(buf *TrackedBuffer) { buf.Myprintf("match(%v) against (%v%s)", node.Columns, node.Expr, node.Option) } -func (node *MatchExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Columns, - node.Expr, - ) -} - -func (node *MatchExpr) replace(from, to Expr) bool { - for _, sel := range node.Columns { - aliased, ok := sel.(*AliasedExpr) - if !ok { - continue - } - if replaceExprs(from, to, &aliased.Expr) { - return true - } - } - return replaceExprs(from, to, &node.Expr) -} - -// CaseExpr represents a CASE expression. -type CaseExpr struct { - Expr Expr - Whens []*When - Else Expr -} - // Format formats the node. 
func (node *CaseExpr) Format(buf *TrackedBuffer) { buf.Myprintf("case ") @@ -3326,35 +1612,6 @@ func (node *CaseExpr) Format(buf *TrackedBuffer) { buf.Myprintf("end") } -func (node *CaseExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - if err := Walk(visit, node.Expr); err != nil { - return err - } - for _, n := range node.Whens { - if err := Walk(visit, n); err != nil { - return err - } - } - return Walk(visit, node.Else) -} - -func (node *CaseExpr) replace(from, to Expr) bool { - for _, when := range node.Whens { - if replaceExprs(from, to, &when.Cond, &when.Val) { - return true - } - } - return replaceExprs(from, to, &node.Expr, &node.Else) -} - -// Default represents a DEFAULT expression. -type Default struct { - ColName string -} - // Format formats the node. func (node *Default) Format(buf *TrackedBuffer) { buf.Myprintf("default") @@ -3363,39 +1620,11 @@ func (node *Default) Format(buf *TrackedBuffer) { } } -func (node *Default) walkSubtree(visit Visit) error { - return nil -} - -func (node *Default) replace(from, to Expr) bool { - return false -} - -// When represents a WHEN sub-expression. -type When struct { - Cond Expr - Val Expr -} - // Format formats the node. func (node *When) Format(buf *TrackedBuffer) { buf.Myprintf("when %v then %v", node.Cond, node.Val) } -func (node *When) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Cond, - node.Val, - ) -} - -// GroupBy represents a GROUP BY clause. -type GroupBy []Expr - // Format formats the node. func (node GroupBy) Format(buf *TrackedBuffer) { prefix := " group by " @@ -3405,18 +1634,6 @@ func (node GroupBy) Format(buf *TrackedBuffer) { } } -func (node GroupBy) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// OrderBy represents an ORDER By clause. -type OrderBy []*Order - // Format formats the node. 
func (node OrderBy) Format(buf *TrackedBuffer) { prefix := " order by " @@ -3426,27 +1643,6 @@ func (node OrderBy) Format(buf *TrackedBuffer) { } } -func (node OrderBy) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// Order represents an ordering expression. -type Order struct { - Expr Expr - Direction string -} - -// Order.Direction -const ( - AscScr = "asc" - DescScr = "desc" -) - // Format formats the node. func (node *Order) Format(buf *TrackedBuffer) { if node, ok := node.Expr.(*NullVal); ok { @@ -3463,21 +1659,6 @@ func (node *Order) Format(buf *TrackedBuffer) { buf.Myprintf("%v %s", node.Expr, node.Direction) } -func (node *Order) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Expr, - ) -} - -// Limit represents a LIMIT clause. -type Limit struct { - Offset, Rowcount Expr -} - // Format formats the node. func (node *Limit) Format(buf *TrackedBuffer) { if node == nil { @@ -3490,20 +1671,6 @@ func (node *Limit) Format(buf *TrackedBuffer) { buf.Myprintf("%v", node.Rowcount) } -func (node *Limit) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Offset, - node.Rowcount, - ) -} - -// Values represents a VALUES clause. -type Values []ValTuple - // Format formats the node. func (node Values) Format(buf *TrackedBuffer) { prefix := "values " @@ -3513,18 +1680,6 @@ func (node Values) Format(buf *TrackedBuffer) { } } -func (node Values) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// UpdateExprs represents a list of update expressions. -type UpdateExprs []*UpdateExpr - // Format formats the node. 
func (node UpdateExprs) Format(buf *TrackedBuffer) { var prefix string @@ -3534,40 +1689,11 @@ func (node UpdateExprs) Format(buf *TrackedBuffer) { } } -func (node UpdateExprs) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// UpdateExpr represents an update expression. -type UpdateExpr struct { - Name *ColName - Expr Expr -} - // Format formats the node. func (node *UpdateExpr) Format(buf *TrackedBuffer) { buf.Myprintf("%v = %v", node.Name, node.Expr) } -func (node *UpdateExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Name, - node.Expr, - ) -} - -// SetExprs represents a list of set expressions. -type SetExprs []*SetExpr - // Format formats the node. func (node SetExprs) Format(buf *TrackedBuffer) { var prefix string @@ -3577,35 +1703,6 @@ func (node SetExprs) Format(buf *TrackedBuffer) { } } -func (node SetExprs) walkSubtree(visit Visit) error { - for _, n := range node { - if err := Walk(visit, n); err != nil { - return err - } - } - return nil -} - -// SetExpr represents a set expression. -type SetExpr struct { - Name ColIdent - Expr Expr -} - -// SetExpr.Expr, for SET TRANSACTION ... or START TRANSACTION -const ( - // TransactionStr is the Name for a SET TRANSACTION statement - TransactionStr = "transaction" - - IsolationLevelReadUncommitted = "isolation level read uncommitted" - IsolationLevelReadCommitted = "isolation level read committed" - IsolationLevelRepeatableRead = "isolation level repeatable read" - IsolationLevelSerializable = "isolation level serializable" - - TxReadOnly = "read only" - TxReadWrite = "read write" -) - // Format formats the node. func (node *SetExpr) Format(buf *TrackedBuffer) { // We don't have to backtick set variable names. 
@@ -3619,20 +1716,6 @@ func (node *SetExpr) Format(buf *TrackedBuffer) { } } -func (node *SetExpr) walkSubtree(visit Visit) error { - if node == nil { - return nil - } - return Walk( - visit, - node.Name, - node.Expr, - ) -} - -// OnDup represents an ON DUPLICATE KEY clause. -type OnDup UpdateExprs - // Format formats the node. func (node OnDup) Format(buf *TrackedBuffer) { if node == nil { @@ -3641,189 +1724,12 @@ func (node OnDup) Format(buf *TrackedBuffer) { buf.Myprintf(" on duplicate key update %v", UpdateExprs(node)) } -func (node OnDup) walkSubtree(visit Visit) error { - return Walk(visit, UpdateExprs(node)) -} - -// ColIdent is a case insensitive SQL identifier. It will be escaped with -// backquotes if necessary. -type ColIdent struct { - // This artifact prevents this struct from being compared - // with itself. It consumes no space as long as it's not the - // last field in the struct. - _ [0]struct{ _ []byte } - val, lowered string -} - -// NewColIdent makes a new ColIdent. -func NewColIdent(str string) ColIdent { - return ColIdent{ - val: str, - } -} - // Format formats the node. func (node ColIdent) Format(buf *TrackedBuffer) { formatID(buf, node.val, node.Lowered()) } -func (node ColIdent) walkSubtree(visit Visit) error { - return nil -} - -// IsEmpty returns true if the name is empty. -func (node ColIdent) IsEmpty() bool { - return node.val == "" -} - -// String returns the unescaped column name. It must -// not be used for SQL generation. Use sqlparser.String -// instead. The Stringer conformance is for usage -// in templates. -func (node ColIdent) String() string { - return node.val -} - -// CompliantName returns a compliant id name -// that can be used for a bind var. -func (node ColIdent) CompliantName() string { - return compliantName(node.val) -} - -// Lowered returns a lower-cased column name. -// This function should generally be used only for optimizing -// comparisons. 
-func (node ColIdent) Lowered() string { - if node.val == "" { - return "" - } - if node.lowered == "" { - node.lowered = strings.ToLower(node.val) - } - return node.lowered -} - -// Equal performs a case-insensitive compare. -func (node ColIdent) Equal(in ColIdent) bool { - return node.Lowered() == in.Lowered() -} - -// EqualString performs a case-insensitive compare with str. -func (node ColIdent) EqualString(str string) bool { - return node.Lowered() == strings.ToLower(str) -} - -// MarshalJSON marshals into JSON. -func (node ColIdent) MarshalJSON() ([]byte, error) { - return json.Marshal(node.val) -} - -// UnmarshalJSON unmarshals from JSON. -func (node *ColIdent) UnmarshalJSON(b []byte) error { - var result string - err := json.Unmarshal(b, &result) - if err != nil { - return err - } - node.val = result - return nil -} - -// TableIdent is a case sensitive SQL identifier. It will be escaped with -// backquotes if necessary. -type TableIdent struct { - v string -} - -// NewTableIdent creates a new TableIdent. -func NewTableIdent(str string) TableIdent { - return TableIdent{v: str} -} - // Format formats the node. func (node TableIdent) Format(buf *TrackedBuffer) { formatID(buf, node.v, strings.ToLower(node.v)) } - -func (node TableIdent) walkSubtree(visit Visit) error { - return nil -} - -// IsEmpty returns true if TabIdent is empty. -func (node TableIdent) IsEmpty() bool { - return node.v == "" -} - -// String returns the unescaped table name. It must -// not be used for SQL generation. Use sqlparser.String -// instead. The Stringer conformance is for usage -// in templates. -func (node TableIdent) String() string { - return node.v -} - -// CompliantName returns a compliant id name -// that can be used for a bind var. -func (node TableIdent) CompliantName() string { - return compliantName(node.v) -} - -// MarshalJSON marshals into JSON. 
-func (node TableIdent) MarshalJSON() ([]byte, error) { - return json.Marshal(node.v) -} - -// UnmarshalJSON unmarshals from JSON. -func (node *TableIdent) UnmarshalJSON(b []byte) error { - var result string - err := json.Unmarshal(b, &result) - if err != nil { - return err - } - node.v = result - return nil -} - -func formatID(buf *TrackedBuffer, original, lowered string) { - isDbSystemVariable := false - if len(original) > 1 && original[:2] == "@@" { - isDbSystemVariable = true - } - - for i, c := range original { - if !isLetter(uint16(c)) && (!isDbSystemVariable || !isCarat(uint16(c))) { - if i == 0 || !isDigit(uint16(c)) { - goto mustEscape - } - } - } - if _, ok := keywords[lowered]; ok { - goto mustEscape - } - buf.Myprintf("%s", original) - return - -mustEscape: - buf.WriteByte('`') - for _, c := range original { - buf.WriteRune(c) - if c == '`' { - buf.WriteByte('`') - } - } - buf.WriteByte('`') -} - -func compliantName(in string) string { - var buf strings.Builder - for i, c := range in { - if !isLetter(uint16(c)) { - if i == 0 || !isDigit(uint16(c)) { - buf.WriteByte('_') - continue - } - } - buf.WriteRune(c) - } - return buf.String() -} diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go new file mode 100644 index 00000000000..a746301fdce --- /dev/null +++ b/go/vt/sqlparser/ast_funcs.go @@ -0,0 +1,750 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sqlparser + +import ( + "encoding/hex" + "encoding/json" + "strings" + + "vitess.io/vitess/go/vt/log" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vterrors" + + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" +) + +// Walk calls visit on every node. +// If visit returns true, the underlying nodes +// are also visited. If it returns an error, walking +// is interrupted, and the error is returned. +func Walk(visit Visit, nodes ...SQLNode) error { + for _, node := range nodes { + if node == nil { + continue + } + var err error + var kontinue bool + pre := func(cursor *Cursor) bool { + // If we already have found an error, don't visit these nodes, just exit early + if err != nil { + return false + } + kontinue, err = visit(cursor.Node()) + if err != nil { + return true // we have to return true here so that post gets called + } + return kontinue + } + post := func(cursor *Cursor) bool { + if err != nil { + return false // now we can abort the traversal if an error was found + } + + return true + } + + Rewrite(node, pre, post) + if err != nil { + return err + } + } + return nil +} + +// Visit defines the signature of a function that +// can be used to visit all nodes of a parse tree. +type Visit func(node SQLNode) (kontinue bool, err error) + +// Append appends the SQLNode to the buffer. 
+func Append(buf *strings.Builder, node SQLNode) { + tbuf := &TrackedBuffer{ + Builder: buf, + } + node.Format(tbuf) +} + +// IndexColumn describes a column in an index definition with optional length +type IndexColumn struct { + Column ColIdent + Length *SQLVal +} + +// LengthScaleOption is used for types that have an optional length +// and scale +type LengthScaleOption struct { + Length *SQLVal + Scale *SQLVal +} + +// IndexOption is used for trailing options for indexes: COMMENT, KEY_BLOCK_SIZE, USING +type IndexOption struct { + Name string + Value *SQLVal + Using string +} + +// ColumnKeyOption indicates whether or not the given column is defined as an +// index element and contains the type of the option +type ColumnKeyOption int + +const ( + colKeyNone ColumnKeyOption = iota + colKeyPrimary + colKeySpatialKey + colKeyUnique + colKeyUniqueKey + colKey +) + +// ReferenceAction indicates the action takes by a referential constraint e.g. +// the `CASCADE` in a `FOREIGN KEY .. ON DELETE CASCADE` table definition. +type ReferenceAction int + +// These map to the SQL-defined reference actions. +// See https://dev.mysql.com/doc/refman/8.0/en/create-table-foreign-keys.html#foreign-keys-referential-actions +const ( + // DefaultAction indicates no action was explicitly specified. + DefaultAction ReferenceAction = iota + Restrict + Cascade + NoAction + SetNull + SetDefault +) + +// ShowTablesOpt is show tables option +type ShowTablesOpt struct { + Full string + DbName string + Filter *ShowFilter +} + +// ValType specifies the type for SQLVal. +type ValType int + +// These are the possible Valtype values. +// HexNum represents a 0x... value. It cannot +// be treated as a simple value because it can +// be interpreted differently depending on the +// context. +const ( + StrVal = ValType(iota) + IntVal + FloatVal + HexNum + HexVal + ValArg + BitVal +) + +// AffectedTables returns the list table names affected by the DDL. 
+func (node *DDL) AffectedTables() TableNames { + if node.Action == RenameStr || node.Action == DropStr { + list := make(TableNames, 0, len(node.FromTables)+len(node.ToTables)) + list = append(list, node.FromTables...) + list = append(list, node.ToTables...) + return list + } + return TableNames{node.Table} +} + +// AddColumn appends the given column to the list in the spec +func (ts *TableSpec) AddColumn(cd *ColumnDefinition) { + ts.Columns = append(ts.Columns, cd) +} + +// AddIndex appends the given index to the list in the spec +func (ts *TableSpec) AddIndex(id *IndexDefinition) { + ts.Indexes = append(ts.Indexes, id) +} + +// AddConstraint appends the given index to the list in the spec +func (ts *TableSpec) AddConstraint(cd *ConstraintDefinition) { + ts.Constraints = append(ts.Constraints, cd) +} + +// DescribeType returns the abbreviated type information as required for +// describe table +func (ct *ColumnType) DescribeType() string { + buf := NewTrackedBuffer(nil) + buf.Myprintf("%s", ct.Type) + if ct.Length != nil && ct.Scale != nil { + buf.Myprintf("(%v,%v)", ct.Length, ct.Scale) + } else if ct.Length != nil { + buf.Myprintf("(%v)", ct.Length) + } + + opts := make([]string, 0, 16) + if ct.Unsigned { + opts = append(opts, keywordStrings[UNSIGNED]) + } + if ct.Zerofill { + opts = append(opts, keywordStrings[ZEROFILL]) + } + if len(opts) != 0 { + buf.Myprintf(" %s", strings.Join(opts, " ")) + } + return buf.String() +} + +// SQLType returns the sqltypes type code for the given column +func (ct *ColumnType) SQLType() querypb.Type { + switch ct.Type { + case keywordStrings[TINYINT]: + if ct.Unsigned { + return sqltypes.Uint8 + } + return sqltypes.Int8 + case keywordStrings[SMALLINT]: + if ct.Unsigned { + return sqltypes.Uint16 + } + return sqltypes.Int16 + case keywordStrings[MEDIUMINT]: + if ct.Unsigned { + return sqltypes.Uint24 + } + return sqltypes.Int24 + case keywordStrings[INT], keywordStrings[INTEGER]: + if ct.Unsigned { + return sqltypes.Uint32 + } + 
return sqltypes.Int32 + case keywordStrings[BIGINT]: + if ct.Unsigned { + return sqltypes.Uint64 + } + return sqltypes.Int64 + case keywordStrings[BOOL], keywordStrings[BOOLEAN]: + return sqltypes.Uint8 + case keywordStrings[TEXT]: + return sqltypes.Text + case keywordStrings[TINYTEXT]: + return sqltypes.Text + case keywordStrings[MEDIUMTEXT]: + return sqltypes.Text + case keywordStrings[LONGTEXT]: + return sqltypes.Text + case keywordStrings[BLOB]: + return sqltypes.Blob + case keywordStrings[TINYBLOB]: + return sqltypes.Blob + case keywordStrings[MEDIUMBLOB]: + return sqltypes.Blob + case keywordStrings[LONGBLOB]: + return sqltypes.Blob + case keywordStrings[CHAR]: + return sqltypes.Char + case keywordStrings[VARCHAR]: + return sqltypes.VarChar + case keywordStrings[BINARY]: + return sqltypes.Binary + case keywordStrings[VARBINARY]: + return sqltypes.VarBinary + case keywordStrings[DATE]: + return sqltypes.Date + case keywordStrings[TIME]: + return sqltypes.Time + case keywordStrings[DATETIME]: + return sqltypes.Datetime + case keywordStrings[TIMESTAMP]: + return sqltypes.Timestamp + case keywordStrings[YEAR]: + return sqltypes.Year + case keywordStrings[FLOAT_TYPE]: + return sqltypes.Float32 + case keywordStrings[DOUBLE]: + return sqltypes.Float64 + case keywordStrings[DECIMAL]: + return sqltypes.Decimal + case keywordStrings[BIT]: + return sqltypes.Bit + case keywordStrings[ENUM]: + return sqltypes.Enum + case keywordStrings[SET]: + return sqltypes.Set + case keywordStrings[JSON]: + return sqltypes.TypeJSON + case keywordStrings[GEOMETRY]: + return sqltypes.Geometry + case keywordStrings[POINT]: + return sqltypes.Geometry + case keywordStrings[LINESTRING]: + return sqltypes.Geometry + case keywordStrings[POLYGON]: + return sqltypes.Geometry + case keywordStrings[GEOMETRYCOLLECTION]: + return sqltypes.Geometry + case keywordStrings[MULTIPOINT]: + return sqltypes.Geometry + case keywordStrings[MULTILINESTRING]: + return sqltypes.Geometry + case 
keywordStrings[MULTIPOLYGON]: + return sqltypes.Geometry + } + panic("unimplemented type " + ct.Type) +} + +// ParseParams parses the vindex parameter list, pulling out the special-case +// "owner" parameter +func (node *VindexSpec) ParseParams() (string, map[string]string) { + var owner string + params := map[string]string{} + for _, p := range node.Params { + if p.Key.Lowered() == VindexOwnerStr { + owner = p.Val + } else { + params[p.Key.String()] = p.Val + } + } + return owner, params +} + +var _ ConstraintInfo = &ForeignKeyDefinition{} + +func (f *ForeignKeyDefinition) iConstraintInfo() {} + +// HasOnTable returns true if the show statement has an "on" clause +func (node *Show) HasOnTable() bool { + return node.OnTable.Name.v != "" +} + +// HasTable returns true if the show statement has a parsed table name. +// Not all show statements parse table names. +func (node *Show) HasTable() bool { + return node.Table.Name.v != "" +} + +// FindColumn finds a column in the column list, returning +// the index if it exists or -1 otherwise +func (node Columns) FindColumn(col ColIdent) int { + for i, colName := range node { + if colName.Equal(col) { + return i + } + } + return -1 +} + +// RemoveHints returns a new AliasedTableExpr with the hints removed. +func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr { + noHints := *node + noHints.Hints = nil + return &noHints +} + +// IsEmpty returns true if TableName is nil or empty. +func (node TableName) IsEmpty() bool { + // If Name is empty, Qualifier is also empty. + return node.Name.IsEmpty() +} + +// ToViewName returns a TableName acceptable for use as a VIEW. VIEW names are +// always lowercase, so ToViewName lowercasese the name. Databases are case-sensitive +// so Qualifier is left untouched. 
+func (node TableName) ToViewName() TableName { + return TableName{ + Qualifier: node.Qualifier, + Name: NewTableIdent(strings.ToLower(node.Name.v)), + } +} + +// NewWhere creates a WHERE or HAVING clause out +// of a Expr. If the expression is nil, it returns nil. +func NewWhere(typ string, expr Expr) *Where { + if expr == nil { + return nil + } + return &Where{Type: typ, Expr: expr} +} + +// ReplaceExpr finds the from expression from root +// and replaces it with to. If from matches root, +// then to is returned. +func ReplaceExpr(root, from, to Expr) Expr { + tmp := Rewrite(root, replaceExpr(from, to), nil) + expr, success := tmp.(Expr) + if !success { + log.Errorf("Failed to rewrite expression. Rewriter returned a non-expression: " + String(tmp)) + return from + } + + return expr +} + +func replaceExpr(from, to Expr) func(cursor *Cursor) bool { + return func(cursor *Cursor) bool { + if cursor.Node() == from { + cursor.Replace(to) + } + switch cursor.Node().(type) { + case *ExistsExpr, *SQLVal, *Subquery, *ValuesFuncExpr, *Default: + return false + } + + return true + } +} + +// IsImpossible returns true if the comparison in the expression can never evaluate to true. +// Note that this is not currently exhaustive to ALL impossible comparisons. +func (node *ComparisonExpr) IsImpossible() bool { + var left, right *SQLVal + var ok bool + if left, ok = node.Left.(*SQLVal); !ok { + return false + } + if right, ok = node.Right.(*SQLVal); !ok { + return false + } + if node.Operator == NotEqualStr && left.Type == right.Type { + if len(left.Val) != len(right.Val) { + return false + } + + for i := range left.Val { + if left.Val[i] != right.Val[i] { + return false + } + } + return true + } + return false +} + +// ExprFromValue converts the given Value into an Expr or returns an error. +func ExprFromValue(value sqltypes.Value) (Expr, error) { + // The type checks here follow the rules defined in sqltypes/types.go. 
+ switch { + case value.Type() == sqltypes.Null: + return &NullVal{}, nil + case value.IsIntegral(): + return NewIntVal(value.ToBytes()), nil + case value.IsFloat() || value.Type() == sqltypes.Decimal: + return NewFloatVal(value.ToBytes()), nil + case value.IsQuoted(): + return NewStrVal(value.ToBytes()), nil + default: + // We cannot support sqltypes.Expression, or any other invalid type. + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot convert value %v to AST", value) + } +} + +// NewStrVal builds a new StrVal. +func NewStrVal(in []byte) *SQLVal { + return &SQLVal{Type: StrVal, Val: in} +} + +// NewIntVal builds a new IntVal. +func NewIntVal(in []byte) *SQLVal { + return &SQLVal{Type: IntVal, Val: in} +} + +// NewFloatVal builds a new FloatVal. +func NewFloatVal(in []byte) *SQLVal { + return &SQLVal{Type: FloatVal, Val: in} +} + +// NewHexNum builds a new HexNum. +func NewHexNum(in []byte) *SQLVal { + return &SQLVal{Type: HexNum, Val: in} +} + +// NewHexVal builds a new HexVal. +func NewHexVal(in []byte) *SQLVal { + return &SQLVal{Type: HexVal, Val: in} +} + +// NewBitVal builds a new BitVal containing a bit literal. +func NewBitVal(in []byte) *SQLVal { + return &SQLVal{Type: BitVal, Val: in} +} + +// NewValArg builds a new ValArg. +func NewValArg(in []byte) *SQLVal { + return &SQLVal{Type: ValArg, Val: in} +} + +// HexDecode decodes the hexval into bytes. +func (node *SQLVal) HexDecode() ([]byte, error) { + dst := make([]byte, hex.DecodedLen(len([]byte(node.Val)))) + _, err := hex.Decode(dst, []byte(node.Val)) + if err != nil { + return nil, err + } + return dst, err +} + +// Equal returns true if the column names match. +func (node *ColName) Equal(c *ColName) bool { + // Failsafe: ColName should not be empty. + if node == nil || c == nil { + return false + } + return node.Name.Equal(c.Name) && node.Qualifier == c.Qualifier +} + +// Aggregates is a map of all aggregate functions. 
+var Aggregates = map[string]bool{ + "avg": true, + "bit_and": true, + "bit_or": true, + "bit_xor": true, + "count": true, + "group_concat": true, + "max": true, + "min": true, + "std": true, + "stddev_pop": true, + "stddev_samp": true, + "stddev": true, + "sum": true, + "var_pop": true, + "var_samp": true, + "variance": true, +} + +// IsAggregate returns true if the function is an aggregate. +func (node *FuncExpr) IsAggregate() bool { + return Aggregates[node.Name.Lowered()] +} + +// NewColIdent makes a new ColIdent. +func NewColIdent(str string) ColIdent { + return ColIdent{ + val: str, + } +} + +// IsEmpty returns true if the name is empty. +func (node ColIdent) IsEmpty() bool { + return node.val == "" +} + +// String returns the unescaped column name. It must +// not be used for SQL generation. Use sqlparser.String +// instead. The Stringer conformance is for usage +// in templates. +func (node ColIdent) String() string { + return node.val +} + +// CompliantName returns a compliant id name +// that can be used for a bind var. +func (node ColIdent) CompliantName() string { + return compliantName(node.val) +} + +// Lowered returns a lower-cased column name. +// This function should generally be used only for optimizing +// comparisons. +func (node ColIdent) Lowered() string { + if node.val == "" { + return "" + } + if node.lowered == "" { + node.lowered = strings.ToLower(node.val) + } + return node.lowered +} + +// Equal performs a case-insensitive compare. +func (node ColIdent) Equal(in ColIdent) bool { + return node.Lowered() == in.Lowered() +} + +// EqualString performs a case-insensitive compare with str. +func (node ColIdent) EqualString(str string) bool { + return node.Lowered() == strings.ToLower(str) +} + +// MarshalJSON marshals into JSON. +func (node ColIdent) MarshalJSON() ([]byte, error) { + return json.Marshal(node.val) +} + +// UnmarshalJSON unmarshals from JSON. 
+func (node *ColIdent) UnmarshalJSON(b []byte) error { + var result string + err := json.Unmarshal(b, &result) + if err != nil { + return err + } + node.val = result + return nil +} + +// NewTableIdent creates a new TableIdent. +func NewTableIdent(str string) TableIdent { + return TableIdent{v: str} +} + +// IsEmpty returns true if TabIdent is empty. +func (node TableIdent) IsEmpty() bool { + return node.v == "" +} + +// String returns the unescaped table name. It must +// not be used for SQL generation. Use sqlparser.String +// instead. The Stringer conformance is for usage +// in templates. +func (node TableIdent) String() string { + return node.v +} + +// CompliantName returns a compliant id name +// that can be used for a bind var. +func (node TableIdent) CompliantName() string { + return compliantName(node.v) +} + +// MarshalJSON marshals into JSON. +func (node TableIdent) MarshalJSON() ([]byte, error) { + return json.Marshal(node.v) +} + +// UnmarshalJSON unmarshals from JSON. +func (node *TableIdent) UnmarshalJSON(b []byte) error { + var result string + err := json.Unmarshal(b, &result) + if err != nil { + return err + } + node.v = result + return nil +} + +func formatID(buf *TrackedBuffer, original, lowered string) { + isDbSystemVariable := false + if len(original) > 1 && original[:2] == "@@" { + isDbSystemVariable = true + } + + for i, c := range original { + if !isLetter(uint16(c)) && (!isDbSystemVariable || !isCarat(uint16(c))) { + if i == 0 || !isDigit(uint16(c)) { + goto mustEscape + } + } + } + if _, ok := keywords[lowered]; ok { + goto mustEscape + } + buf.Myprintf("%s", original) + return + +mustEscape: + buf.WriteByte('`') + for _, c := range original { + buf.WriteRune(c) + if c == '`' { + buf.WriteByte('`') + } + } + buf.WriteByte('`') +} + +func compliantName(in string) string { + var buf strings.Builder + for i, c := range in { + if !isLetter(uint16(c)) { + if i == 0 || !isDigit(uint16(c)) { + buf.WriteByte('_') + continue + } + } + 
buf.WriteRune(c) + } + return buf.String() +} + +// AddOrder adds an order by element +func (node *Select) AddOrder(order *Order) { + node.OrderBy = append(node.OrderBy, order) +} + +// SetLimit sets the limit clause +func (node *Select) SetLimit(limit *Limit) { + node.Limit = limit +} + +// AddWhere adds the boolean expression to the +// WHERE clause as an AND condition. If the expression +// is an OR clause, it parenthesizes it. Currently, +// the OR operator is the only one that's lower precedence +// than AND. +func (node *Select) AddWhere(expr Expr) { + if _, ok := expr.(*OrExpr); ok { + expr = &ParenExpr{Expr: expr} + } + if node.Where == nil { + node.Where = &Where{ + Type: WhereStr, + Expr: expr, + } + return + } + node.Where.Expr = &AndExpr{ + Left: node.Where.Expr, + Right: expr, + } +} + +// AddHaving adds the boolean expression to the +// HAVING clause as an AND condition. If the expression +// is an OR clause, it parenthesizes it. Currently, +// the OR operator is the only one that's lower precedence +// than AND. 
+func (node *Select) AddHaving(expr Expr) { + if _, ok := expr.(*OrExpr); ok { + expr = &ParenExpr{Expr: expr} + } + if node.Having == nil { + node.Having = &Where{ + Type: HavingStr, + Expr: expr, + } + return + } + node.Having.Expr = &AndExpr{ + Left: node.Having.Expr, + Right: expr, + } +} + +// AddOrder adds an order by element +func (node *ParenSelect) AddOrder(order *Order) { + panic("unreachable") +} + +// SetLimit sets the limit clause +func (node *ParenSelect) SetLimit(limit *Limit) { + panic("unreachable") +} + +// AddOrder adds an order by element +func (node *Union) AddOrder(order *Order) { + node.OrderBy = append(node.OrderBy, order) +} + +// SetLimit sets the limit clause +func (node *Union) SetLimit(limit *Limit) { + node.Limit = limit +} diff --git a/go/vt/sqlparser/constants.go b/go/vt/sqlparser/constants.go new file mode 100644 index 00000000000..b194967fa3b --- /dev/null +++ b/go/vt/sqlparser/constants.go @@ -0,0 +1,163 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +const ( + // Select.Distinct + DistinctStr = "distinct " + StraightJoinHint = "straight_join " + + // Select.Lock + ForUpdateStr = " for update" + ShareModeStr = " lock in share mode" + + // Select.Cache + SQLCacheStr = "sql_cache " + SQLNoCacheStr = "sql_no_cache " + + // Union.Type + UnionStr = "union" + UnionAllStr = "union all" + UnionDistinctStr = "union distinct" + + // DDL strings. 
+ InsertStr = "insert" + ReplaceStr = "replace" + + // Set.Scope or Show.Scope + SessionStr = "session" + GlobalStr = "global" + VitessMetadataStr = "vitess_metadata" + ImplicitStr = "" + + // DDL strings. + CreateStr = "create" + AlterStr = "alter" + DropStr = "drop" + RenameStr = "rename" + TruncateStr = "truncate" + FlushStr = "flush" + CreateVindexStr = "create vindex" + DropVindexStr = "drop vindex" + AddVschemaTableStr = "add vschema table" + DropVschemaTableStr = "drop vschema table" + AddColVindexStr = "on table add vindex" + DropColVindexStr = "on table drop vindex" + AddSequenceStr = "add sequence" + AddAutoIncStr = "add auto_increment" + + // Vindex DDL param to specify the owner of a vindex + VindexOwnerStr = "owner" + + // Partition strings + ReorganizeStr = "reorganize partition" + + // JoinTableExpr.Join + JoinStr = "join" + StraightJoinStr = "straight_join" + LeftJoinStr = "left join" + RightJoinStr = "right join" + NaturalJoinStr = "natural join" + NaturalLeftJoinStr = "natural left join" + NaturalRightJoinStr = "natural right join" + + // Index hints. 
+ UseStr = "use " + IgnoreStr = "ignore " + ForceStr = "force " + + // Where.Type + WhereStr = "where" + HavingStr = "having" + + // ComparisonExpr.Operator + EqualStr = "=" + LessThanStr = "<" + GreaterThanStr = ">" + LessEqualStr = "<=" + GreaterEqualStr = ">=" + NotEqualStr = "!=" + NullSafeEqualStr = "<=>" + InStr = "in" + NotInStr = "not in" + LikeStr = "like" + NotLikeStr = "not like" + RegexpStr = "regexp" + NotRegexpStr = "not regexp" + JSONExtractOp = "->" + JSONUnquoteExtractOp = "->>" + + // RangeCond.Operator + BetweenStr = "between" + NotBetweenStr = "not between" + + // IsExpr.Operator + IsNullStr = "is null" + IsNotNullStr = "is not null" + IsTrueStr = "is true" + IsNotTrueStr = "is not true" + IsFalseStr = "is false" + IsNotFalseStr = "is not false" + + // BinaryExpr.Operator + BitAndStr = "&" + BitOrStr = "|" + BitXorStr = "^" + PlusStr = "+" + MinusStr = "-" + MultStr = "*" + DivStr = "/" + IntDivStr = "div" + ModStr = "%" + ShiftLeftStr = "<<" + ShiftRightStr = ">>" + + // UnaryExpr.Operator + UPlusStr = "+" + UMinusStr = "-" + TildaStr = "~" + BangStr = "!" + BinaryStr = "binary " + UBinaryStr = "_binary " + Utf8mb4Str = "_utf8mb4 " + + // this string is "character set" and this comment is required + CharacterSetStr = " character set" + CharsetStr = "charset" + + // MatchExpr.Option + BooleanModeStr = " in boolean mode" + NaturalLanguageModeStr = " in natural language mode" + NaturalLanguageModeWithQueryExpansionStr = " in natural language mode with query expansion" + QueryExpansionStr = " with query expansion" + + // Order.Direction + AscScr = "asc" + DescScr = "desc" + + // SetExpr.Expr, for SET TRANSACTION ... 
or START TRANSACTION + // TransactionStr is the Name for a SET TRANSACTION statement + TransactionStr = "transaction" + + IsolationLevelReadUncommitted = "isolation level read uncommitted" + IsolationLevelReadCommitted = "isolation level read committed" + IsolationLevelRepeatableRead = "isolation level repeatable read" + IsolationLevelSerializable = "isolation level serializable" + + TxReadOnly = "read only" + TxReadWrite = "read write" +) diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go new file mode 100644 index 00000000000..89a5c9ebfe8 --- /dev/null +++ b/go/vt/sqlparser/expression_rewriting.go @@ -0,0 +1,112 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package sqlparser
+
+import (
+	"vitess.io/vitess/go/vt/log"
+	querypb "vitess.io/vitess/go/vt/proto/query"
+	"vitess.io/vitess/go/vt/proto/vtrpc"
+	"vitess.io/vitess/go/vt/vterrors"
+)
+
+// PrepareAST will normalize the query
+func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix string) (*RewriteASTResult, error) {
+	Normalize(in, bindVars, prefix)
+	return RewriteAST(in)
+}
+
+// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
+// NOTE(review): goingDown records failures in expressionRewriter.err (e.g. for
+// LAST_INSERT_ID(x) or DATABASE(x) with arguments), but er.err is never read
+// here, so those errors are silently dropped and this function always returns
+// a nil error — consider returning er.err. TODO confirm intended behavior.
+func RewriteAST(in Statement) (*RewriteASTResult, error) {
+	er := new(expressionRewriter)
+	Rewrite(in, er.goingDown, nil)
+
+	return &RewriteASTResult{
+		AST: in,
+		NeedLastInsertID: er.lastInsertID,
+		NeedDatabase: er.database,
+	}, nil
+}
+
+// RewriteASTResult contains the rewritten ast and meta information about it
+type RewriteASTResult struct {
+	AST Statement
+	NeedLastInsertID bool
+	NeedDatabase bool
+}
+
+// expressionRewriter tracks which reserved bind variables the rewrite
+// introduced, plus the first error encountered while rewriting.
+type expressionRewriter struct {
+	lastInsertID, database bool
+	err error
+}
+
+const (
+	//LastInsertIDName is a reserved bind var name for last_insert_id()
+	LastInsertIDName = "__lastInsertId"
+	//DBVarName is a reserved bind var name for database()
+	DBVarName = "__vtdbname"
+)
+
+// goingDown is the pre-order visitor used by RewriteAST. It replaces
+// last_insert_id() and database() calls with the reserved bind variables
+// above, and aliases an otherwise-unaliased select expression back to its
+// original SQL text whenever the inner rewrite changed it.
+func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
+	switch node := cursor.Node().(type) {
+	case *AliasedExpr:
+		if node.As.IsEmpty() {
+			// Capture the expression's original SQL text before rewriting,
+			// so it can be used as the column alias afterwards.
+			buf := NewTrackedBuffer(nil)
+			node.Expr.Format(buf)
+			inner := new(expressionRewriter)
+			tmp := Rewrite(node.Expr, inner.goingDown, nil)
+			newExpr, ok := tmp.(Expr)
+			if !ok {
+				log.Errorf("failed to rewrite AST. function expected to return Expr returned a %s", String(tmp))
+				return false
+			}
+			node.Expr = newExpr
+			// NOTE(review): inner.err is not propagated to er here — an error
+			// found while rewriting the aliased expression is lost.
+			er.database = er.database || inner.database
+			er.lastInsertID = er.lastInsertID || inner.lastInsertID
+			if inner.didAnythingChange() {
+				node.As = NewColIdent(buf.String())
+			}
+			// Return false: the subtree was already handled by the inner rewrite.
+			return false
+		}
+
+	case *FuncExpr:
+		switch {
+		case node.Name.EqualString("last_insert_id"):
+			if len(node.Exprs) > 0 {
+				er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
+			} else {
+				cursor.Replace(bindVarExpression(LastInsertIDName))
+				er.lastInsertID = true
+			}
+		case node.Name.EqualString("database"):
+			if len(node.Exprs) > 0 {
+				er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. DATABASE() takes no arguments")
+			} else {
+				cursor.Replace(bindVarExpression(DBVarName))
+				er.database = true
+			}
+		}
+	}
+	return true
+}
+
+// didAnythingChange reports whether this rewriter replaced a
+// last_insert_id() or database() call.
+func (er *expressionRewriter) didAnythingChange() bool {
+	return er.database || er.lastInsertID
+}
+
+// bindVarExpression builds a ":name" bind-variable SQLVal.
+func bindVarExpression(name string) *SQLVal {
+	return NewValArg([]byte(":" + name))
+}
diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go
new file mode 100644
index 00000000000..9bd5cf89375
--- /dev/null
+++ b/go/vt/sqlparser/expression_rewriting_test.go
@@ -0,0 +1,92 @@
+/*
+Copyright 2019 The Vitess Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type myTestCase struct { + in, expected string + liid, db bool +} + +func TestRewrites(in *testing.T) { + tests := []myTestCase{ + { + in: "SELECT 42", + expected: "SELECT 42", + db: false, liid: false, + }, + { + in: "SELECT last_insert_id()", + expected: "SELECT :__lastInsertId as `last_insert_id()`", + db: false, liid: true, + }, + { + in: "SELECT database()", + expected: "SELECT :__vtdbname as `database()`", + db: true, liid: false, + }, + { + in: "SELECT last_insert_id() as test", + expected: "SELECT :__lastInsertId as test", + db: false, liid: true, + }, + { + in: "SELECT last_insert_id() + database()", + expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", + db: true, liid: true, + }, + { + in: "select (select database() from test) from test", + expected: "select (select :__vtdbname as `database()` from test) as `(select database() from test)` from test", + db: true, liid: false, + }, + { + in: "select id from user where database()", + expected: "select id from user where :__vtdbname", + db: true, liid: false, + }, + } + + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + stmt, err := Parse(tc.in) + require.NoError(t, err) + + result, err := RewriteAST(stmt) + require.NoError(t, err) + + expected, err := Parse(tc.expected) + require.NoError(t, err) + + s := toString(expected) + require.Equal(t, s, toString(result.AST)) + require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id") + require.Equal(t, tc.db, result.NeedDatabase, "should need database name") + }) + } +} + +func toString(node SQLNode) string { + buf := NewTrackedBuffer(nil) + node.Format(buf) + return buf.String() +} diff --git a/go/vt/sqlparser/parser.go b/go/vt/sqlparser/parser.go new file mode 100644 index 00000000000..1885a408a70 --- /dev/null +++ b/go/vt/sqlparser/parser.go @@ -0,0 +1,222 @@ +/* +Copyright 2019 The Vitess 
Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "errors" + "fmt" + "io" + "sync" + + "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/vterrors" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" +) + +// parserPool is a pool for parser objects. +var parserPool = sync.Pool{} + +// zeroParser is a zero-initialized parser to help reinitialize the parser for pooling. +var zeroParser = *(yyNewParser().(*yyParserImpl)) + +// yyParsePooled is a wrapper around yyParse that pools the parser objects. There isn't a +// particularly good reason to use yyParse directly, since it immediately discards its parser. What +// would be ideal down the line is to actually pool the stacks themselves rather than the parser +// objects, as per https://github.com/cznic/goyacc/blob/master/main.go. However, absent an upstream +// change to goyacc, this is the next best option. +// +// N.B: Parser pooling means that you CANNOT take references directly to parse stack variables (e.g. +// $$ = &$4) in sql.y rules. You must instead add an intermediate reference like so: +// showCollationFilterOpt := $4 +// $$ = &Show{Type: string($2), ShowCollationFilterOpt: &showCollationFilterOpt} +func yyParsePooled(yylex yyLexer) int { + // Being very particular about using the base type and not an interface type b/c we depend on + // the implementation to know how to reinitialize the parser. 
+ var parser *yyParserImpl + + i := parserPool.Get() + if i != nil { + parser = i.(*yyParserImpl) + } else { + parser = yyNewParser().(*yyParserImpl) + } + + defer func() { + *parser = zeroParser + parserPool.Put(parser) + }() + return parser.Parse(yylex) +} + +// Instructions for creating new types: If a type +// needs to satisfy an interface, declare that function +// along with that interface. This will help users +// identify the list of types to which they can assert +// those interfaces. +// If the member of a type has a string with a predefined +// list of values, declare those values as const following +// the type. +// For interfaces that define dummy functions to consolidate +// a set of types, define the function as iTypeName. +// This will help avoid name collisions. + +// Parse parses the SQL in full and returns a Statement, which +// is the AST representation of the query. If a DDL statement +// is partially parsed but still contains a syntax error, the +// error is ignored and the DDL is returned anyway. +func Parse(sql string) (Statement, error) { + tokenizer := NewStringTokenizer(sql) + if yyParsePooled(tokenizer) != 0 { + if tokenizer.partialDDL != nil { + if typ, val := tokenizer.Scan(); typ != 0 { + return nil, fmt.Errorf("extra characters encountered after end of DDL: '%s'", string(val)) + } + log.Warningf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError) + tokenizer.ParseTree = tokenizer.partialDDL + return tokenizer.ParseTree, nil + } + return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, tokenizer.LastError.Error()) + } + if tokenizer.ParseTree == nil { + return nil, ErrEmpty + } + return tokenizer.ParseTree, nil +} + +// ParseStrictDDL is the same as Parse except it errors on +// partially parsed DDL statements. 
+func ParseStrictDDL(sql string) (Statement, error) { + tokenizer := NewStringTokenizer(sql) + if yyParsePooled(tokenizer) != 0 { + return nil, tokenizer.LastError + } + if tokenizer.ParseTree == nil { + return nil, ErrEmpty + } + return tokenizer.ParseTree, nil +} + +// ParseTokenizer is a raw interface to parse from the given tokenizer. +// This does not used pooled parsers, and should not be used in general. +func ParseTokenizer(tokenizer *Tokenizer) int { + return yyParse(tokenizer) +} + +// ParseNext parses a single SQL statement from the tokenizer +// returning a Statement which is the AST representation of the query. +// The tokenizer will always read up to the end of the statement, allowing for +// the next call to ParseNext to parse any subsequent SQL statements. When +// there are no more statements to parse, a error of io.EOF is returned. +func ParseNext(tokenizer *Tokenizer) (Statement, error) { + return parseNext(tokenizer, false) +} + +// ParseNextStrictDDL is the same as ParseNext except it errors on +// partially parsed DDL statements. +func ParseNextStrictDDL(tokenizer *Tokenizer) (Statement, error) { + return parseNext(tokenizer, true) +} + +func parseNext(tokenizer *Tokenizer, strict bool) (Statement, error) { + if tokenizer.lastChar == ';' { + tokenizer.next() + tokenizer.skipBlank() + } + if tokenizer.lastChar == eofChar { + return nil, io.EOF + } + + tokenizer.reset() + tokenizer.multi = true + if yyParsePooled(tokenizer) != 0 { + if tokenizer.partialDDL != nil && !strict { + tokenizer.ParseTree = tokenizer.partialDDL + return tokenizer.ParseTree, nil + } + return nil, tokenizer.LastError + } + if tokenizer.ParseTree == nil { + return ParseNext(tokenizer) + } + return tokenizer.ParseTree, nil +} + +// ErrEmpty is a sentinel error returned when parsing empty statements. 
+var ErrEmpty = errors.New("empty statement")
+
+// SplitStatement returns the first sql statement up to either a ; or EOF
+// and the remainder from the given buffer
+func SplitStatement(blob string) (string, string, error) {
+	tokenizer := NewStringTokenizer(blob)
+	tkn := 0
+	for {
+		tkn, _ = tokenizer.Scan()
+		if tkn == 0 || tkn == ';' || tkn == eofChar {
+			break
+		}
+	}
+	if tokenizer.LastError != nil {
+		return "", "", tokenizer.LastError
+	}
+	if tkn == ';' {
+		// NOTE(review): assumes tokenizer.Position is one past the character
+		// just consumed, so Position-2 indexes the ';' itself and Position-1
+		// starts the remainder — TODO confirm against Tokenizer.Scan.
+		return blob[:tokenizer.Position-2], blob[tokenizer.Position-1:], nil
+	}
+	return blob, "", nil
+}
+
+// SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces
+// returns the sql pieces blob contains; or error if sql cannot be parsed
+func SplitStatementToPieces(blob string) (pieces []string, err error) {
+	pieces = make([]string, 0, 16)
+	tokenizer := NewStringTokenizer(blob)
+
+	tkn := 0
+	var stmt string
+	stmtBegin := 0
+	for {
+		tkn, _ = tokenizer.Scan()
+		if tkn == ';' {
+			// Slice out the statement text, excluding the ';' terminator
+			// (same Position arithmetic as SplitStatement above).
+			stmt = blob[stmtBegin : tokenizer.Position-2]
+			pieces = append(pieces, stmt)
+			stmtBegin = tokenizer.Position - 1
+
+		} else if tkn == 0 || tkn == eofChar {
+			blobTail := tokenizer.Position - 2
+
+			// Keep a trailing statement that was not ';'-terminated.
+			if stmtBegin < blobTail {
+				stmt = blob[stmtBegin : blobTail+1]
+				pieces = append(pieces, stmt)
+			}
+			break
+		}
+	}
+
+	err = tokenizer.LastError
+	return
+}
+
+// String returns a string representation of an SQLNode.
+func String(node SQLNode) string {
+	if node == nil {
+		return ""
+	}
+
+	buf := NewTrackedBuffer(nil)
+	buf.Myprintf("%v", node)
+	return buf.String()
+}
diff --git a/go/vt/sqlparser/rewriter.go b/go/vt/sqlparser/rewriter.go
new file mode 100644
index 00000000000..28d5ba52688
--- /dev/null
+++ b/go/vt/sqlparser/rewriter.go
@@ -0,0 +1,1320 @@
+// Code generated by visitorgen/main/main.go. DO NOT EDIT.
+ +package sqlparser + +//go:generate make visitor + +import ( + "reflect" +) + +type replacerFunc func(newNode, parent SQLNode) + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor +} + +func replaceAliasedExprAs(newNode, parent SQLNode) { + parent.(*AliasedExpr).As = newNode.(ColIdent) +} + +func replaceAliasedExprExpr(newNode, parent SQLNode) { + parent.(*AliasedExpr).Expr = newNode.(Expr) +} + +func replaceAliasedTableExprAs(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).As = newNode.(TableIdent) +} + +func replaceAliasedTableExprExpr(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) +} + +func replaceAliasedTableExprHints(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) +} + +func replaceAliasedTableExprPartitions(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) +} + +func replaceAndExprLeft(newNode, parent SQLNode) { + parent.(*AndExpr).Left = newNode.(Expr) +} + +func replaceAndExprRight(newNode, parent SQLNode) { + parent.(*AndExpr).Right = newNode.(Expr) +} + +func replaceAutoIncSpecColumn(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Column = newNode.(ColIdent) +} + +func replaceAutoIncSpecSequence(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Sequence = newNode.(TableName) +} + +func replaceBinaryExprLeft(newNode, parent SQLNode) { + parent.(*BinaryExpr).Left = newNode.(Expr) +} + +func replaceBinaryExprRight(newNode, parent SQLNode) { + parent.(*BinaryExpr).Right = newNode.(Expr) +} + +func replaceCaseExprElse(newNode, parent SQLNode) { + parent.(*CaseExpr).Else = newNode.(Expr) +} + +func replaceCaseExprExpr(newNode, parent SQLNode) { + parent.(*CaseExpr).Expr = newNode.(Expr) +} + +type replaceCaseExprWhens int + +func (r *replaceCaseExprWhens) replace(newNode, container SQLNode) { + container.(*CaseExpr).Whens[int(*r)] = 
newNode.(*When) +} + +func (r *replaceCaseExprWhens) inc() { + *r++ +} + +func replaceColNameName(newNode, parent SQLNode) { + parent.(*ColName).Name = newNode.(ColIdent) +} + +func replaceColNameQualifier(newNode, parent SQLNode) { + parent.(*ColName).Qualifier = newNode.(TableName) +} + +func replaceCollateExprExpr(newNode, parent SQLNode) { + parent.(*CollateExpr).Expr = newNode.(Expr) +} + +func replaceColumnDefinitionName(newNode, parent SQLNode) { + parent.(*ColumnDefinition).Name = newNode.(ColIdent) +} + +func replaceColumnTypeAutoincrement(newNode, parent SQLNode) { + parent.(*ColumnType).Autoincrement = newNode.(BoolVal) +} + +func replaceColumnTypeComment(newNode, parent SQLNode) { + parent.(*ColumnType).Comment = newNode.(*SQLVal) +} + +func replaceColumnTypeDefault(newNode, parent SQLNode) { + parent.(*ColumnType).Default = newNode.(Expr) +} + +func replaceColumnTypeLength(newNode, parent SQLNode) { + parent.(*ColumnType).Length = newNode.(*SQLVal) +} + +func replaceColumnTypeNotNull(newNode, parent SQLNode) { + parent.(*ColumnType).NotNull = newNode.(BoolVal) +} + +func replaceColumnTypeOnUpdate(newNode, parent SQLNode) { + parent.(*ColumnType).OnUpdate = newNode.(Expr) +} + +func replaceColumnTypeScale(newNode, parent SQLNode) { + parent.(*ColumnType).Scale = newNode.(*SQLVal) +} + +func replaceColumnTypeUnsigned(newNode, parent SQLNode) { + parent.(*ColumnType).Unsigned = newNode.(BoolVal) +} + +func replaceColumnTypeZerofill(newNode, parent SQLNode) { + parent.(*ColumnType).Zerofill = newNode.(BoolVal) +} + +type replaceColumnsItems int + +func (r *replaceColumnsItems) replace(newNode, container SQLNode) { + container.(Columns)[int(*r)] = newNode.(ColIdent) +} + +func (r *replaceColumnsItems) inc() { + *r++ +} + +func replaceComparisonExprEscape(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Escape = newNode.(Expr) +} + +func replaceComparisonExprLeft(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Left = newNode.(Expr) +} + +func 
replaceComparisonExprRight(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Right = newNode.(Expr) +} + +func replaceConstraintDefinitionDetails(newNode, parent SQLNode) { + parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) +} + +func replaceConvertExprExpr(newNode, parent SQLNode) { + parent.(*ConvertExpr).Expr = newNode.(Expr) +} + +func replaceConvertExprType(newNode, parent SQLNode) { + parent.(*ConvertExpr).Type = newNode.(*ConvertType) +} + +func replaceConvertTypeLength(newNode, parent SQLNode) { + parent.(*ConvertType).Length = newNode.(*SQLVal) +} + +func replaceConvertTypeScale(newNode, parent SQLNode) { + parent.(*ConvertType).Scale = newNode.(*SQLVal) +} + +func replaceConvertUsingExprExpr(newNode, parent SQLNode) { + parent.(*ConvertUsingExpr).Expr = newNode.(Expr) +} + +func replaceCurTimeFuncExprFsp(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) +} + +func replaceCurTimeFuncExprName(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) +} + +func replaceDDLAutoIncSpec(newNode, parent SQLNode) { + parent.(*DDL).AutoIncSpec = newNode.(*AutoIncSpec) +} + +func replaceDDLFromTables(newNode, parent SQLNode) { + parent.(*DDL).FromTables = newNode.(TableNames) +} + +func replaceDDLOptLike(newNode, parent SQLNode) { + parent.(*DDL).OptLike = newNode.(*OptLike) +} + +func replaceDDLPartitionSpec(newNode, parent SQLNode) { + parent.(*DDL).PartitionSpec = newNode.(*PartitionSpec) +} + +func replaceDDLTable(newNode, parent SQLNode) { + parent.(*DDL).Table = newNode.(TableName) +} + +func replaceDDLTableSpec(newNode, parent SQLNode) { + parent.(*DDL).TableSpec = newNode.(*TableSpec) +} + +func replaceDDLToTables(newNode, parent SQLNode) { + parent.(*DDL).ToTables = newNode.(TableNames) +} + +type replaceDDLVindexCols int + +func (r *replaceDDLVindexCols) replace(newNode, container SQLNode) { + container.(*DDL).VindexCols[int(*r)] = newNode.(ColIdent) +} + +func (r *replaceDDLVindexCols) 
inc() { + *r++ +} + +func replaceDDLVindexSpec(newNode, parent SQLNode) { + parent.(*DDL).VindexSpec = newNode.(*VindexSpec) +} + +func replaceDeleteComments(newNode, parent SQLNode) { + parent.(*Delete).Comments = newNode.(Comments) +} + +func replaceDeleteLimit(newNode, parent SQLNode) { + parent.(*Delete).Limit = newNode.(*Limit) +} + +func replaceDeleteOrderBy(newNode, parent SQLNode) { + parent.(*Delete).OrderBy = newNode.(OrderBy) +} + +func replaceDeletePartitions(newNode, parent SQLNode) { + parent.(*Delete).Partitions = newNode.(Partitions) +} + +func replaceDeleteTableExprs(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) +} + +func replaceDeleteTargets(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) +} + +func replaceDeleteWhere(newNode, parent SQLNode) { + parent.(*Delete).Where = newNode.(*Where) +} + +func replaceExistsExprSubquery(newNode, parent SQLNode) { + parent.(*ExistsExpr).Subquery = newNode.(*Subquery) +} + +type replaceExprsItems int + +func (r *replaceExprsItems) replace(newNode, container SQLNode) { + container.(Exprs)[int(*r)] = newNode.(Expr) +} + +func (r *replaceExprsItems) inc() { + *r++ +} + +func replaceForeignKeyDefinitionOnDelete(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) +} + +func replaceForeignKeyDefinitionOnUpdate(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) +} + +func replaceForeignKeyDefinitionReferencedColumns(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) +} + +func replaceForeignKeyDefinitionReferencedTable(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) +} + +func replaceForeignKeyDefinitionSource(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).Source = newNode.(Columns) +} + +func replaceFuncExprExprs(newNode, parent SQLNode) { + 
parent.(*FuncExpr).Exprs = newNode.(SelectExprs) +} + +func replaceFuncExprName(newNode, parent SQLNode) { + parent.(*FuncExpr).Name = newNode.(ColIdent) +} + +func replaceFuncExprQualifier(newNode, parent SQLNode) { + parent.(*FuncExpr).Qualifier = newNode.(TableIdent) +} + +type replaceGroupByItems int + +func (r *replaceGroupByItems) replace(newNode, container SQLNode) { + container.(GroupBy)[int(*r)] = newNode.(Expr) +} + +func (r *replaceGroupByItems) inc() { + *r++ +} + +func replaceGroupConcatExprExprs(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) +} + +func replaceGroupConcatExprOrderBy(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) +} + +func replaceIndexDefinitionInfo(newNode, parent SQLNode) { + parent.(*IndexDefinition).Info = newNode.(*IndexInfo) +} + +type replaceIndexHintsIndexes int + +func (r *replaceIndexHintsIndexes) replace(newNode, container SQLNode) { + container.(*IndexHints).Indexes[int(*r)] = newNode.(ColIdent) +} + +func (r *replaceIndexHintsIndexes) inc() { + *r++ +} + +func replaceIndexInfoName(newNode, parent SQLNode) { + parent.(*IndexInfo).Name = newNode.(ColIdent) +} + +func replaceInsertColumns(newNode, parent SQLNode) { + parent.(*Insert).Columns = newNode.(Columns) +} + +func replaceInsertComments(newNode, parent SQLNode) { + parent.(*Insert).Comments = newNode.(Comments) +} + +func replaceInsertOnDup(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) +} + +func replaceInsertPartitions(newNode, parent SQLNode) { + parent.(*Insert).Partitions = newNode.(Partitions) +} + +func replaceInsertRows(newNode, parent SQLNode) { + parent.(*Insert).Rows = newNode.(InsertRows) +} + +func replaceInsertTable(newNode, parent SQLNode) { + parent.(*Insert).Table = newNode.(TableName) +} + +func replaceIntervalExprExpr(newNode, parent SQLNode) { + parent.(*IntervalExpr).Expr = newNode.(Expr) +} + +func replaceIsExprExpr(newNode, parent SQLNode) { + 
	parent.(*IsExpr).Expr = newNode.(Expr)
+}
+
+// NOTE(review): JoinCondition is type-asserted to its value (non-pointer)
+// type, so tmp is a copy; assigning to tmp.On mutates only that copy and the
+// replacement is silently lost. The same no-op pattern appears in
+// replaceJoinConditionUsing and replaceNextvalExpr below — TODO confirm and
+// fix in the visitorgen generator, not here (file is generated).
+func replaceJoinConditionOn(newNode, parent SQLNode) {
+	tmp := parent.(JoinCondition)
+	tmp.On = newNode.(Expr)
+}
+
+// NOTE(review): same value-copy no-op as replaceJoinConditionOn above.
+func replaceJoinConditionUsing(newNode, parent SQLNode) {
+	tmp := parent.(JoinCondition)
+	tmp.Using = newNode.(Columns)
+}
+
+func replaceJoinTableExprCondition(newNode, parent SQLNode) {
+	parent.(*JoinTableExpr).Condition = newNode.(JoinCondition)
+}
+
+func replaceJoinTableExprLeftExpr(newNode, parent SQLNode) {
+	parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr)
+}
+
+func replaceJoinTableExprRightExpr(newNode, parent SQLNode) {
+	parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr)
+}
+
+func replaceLimitOffset(newNode, parent SQLNode) {
+	parent.(*Limit).Offset = newNode.(Expr)
+}
+
+func replaceLimitRowcount(newNode, parent SQLNode) {
+	parent.(*Limit).Rowcount = newNode.(Expr)
+}
+
+func replaceMatchExprColumns(newNode, parent SQLNode) {
+	parent.(*MatchExpr).Columns = newNode.(SelectExprs)
+}
+
+func replaceMatchExprExpr(newNode, parent SQLNode) {
+	parent.(*MatchExpr).Expr = newNode.(Expr)
+}
+
+// NOTE(review): Nextval is asserted to its value type, so tmp is a copy and
+// the assignment below has no effect on the tree (see replaceJoinConditionOn).
+func replaceNextvalExpr(newNode, parent SQLNode) {
+	tmp := parent.(Nextval)
+	tmp.Expr = newNode.(Expr)
+}
+
+func replaceNotExprExpr(newNode, parent SQLNode) {
+	parent.(*NotExpr).Expr = newNode.(Expr)
+}
+
+type replaceOnDupItems int
+
+func (r *replaceOnDupItems) replace(newNode, container SQLNode) {
+	container.(OnDup)[int(*r)] = newNode.(*UpdateExpr)
+}
+
+func (r *replaceOnDupItems) inc() {
+	*r++
+}
+
+func replaceOptLikeLikeTable(newNode, parent SQLNode) {
+	parent.(*OptLike).LikeTable = newNode.(TableName)
+}
+
+func replaceOrExprLeft(newNode, parent SQLNode) {
+	parent.(*OrExpr).Left = newNode.(Expr)
+}
+
+func replaceOrExprRight(newNode, parent SQLNode) {
+	parent.(*OrExpr).Right = newNode.(Expr)
+}
+
+func replaceOrderExpr(newNode, parent SQLNode) {
+	parent.(*Order).Expr = newNode.(Expr)
+}
+
+type replaceOrderByItems int
+
+func (r *replaceOrderByItems) replace(newNode,
container SQLNode) { + container.(OrderBy)[int(*r)] = newNode.(*Order) +} + +func (r *replaceOrderByItems) inc() { + *r++ +} + +func replaceParenExprExpr(newNode, parent SQLNode) { + parent.(*ParenExpr).Expr = newNode.(Expr) +} + +func replaceParenSelectSelect(newNode, parent SQLNode) { + parent.(*ParenSelect).Select = newNode.(SelectStatement) +} + +func replaceParenTableExprExprs(newNode, parent SQLNode) { + parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) +} + +func replacePartitionDefinitionLimit(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Limit = newNode.(Expr) +} + +func replacePartitionDefinitionName(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Name = newNode.(ColIdent) +} + +type replacePartitionSpecDefinitions int + +func (r *replacePartitionSpecDefinitions) replace(newNode, container SQLNode) { + container.(*PartitionSpec).Definitions[int(*r)] = newNode.(*PartitionDefinition) +} + +func (r *replacePartitionSpecDefinitions) inc() { + *r++ +} + +func replacePartitionSpecName(newNode, parent SQLNode) { + parent.(*PartitionSpec).Name = newNode.(ColIdent) +} + +type replacePartitionsItems int + +func (r *replacePartitionsItems) replace(newNode, container SQLNode) { + container.(Partitions)[int(*r)] = newNode.(ColIdent) +} + +func (r *replacePartitionsItems) inc() { + *r++ +} + +func replaceRangeCondFrom(newNode, parent SQLNode) { + parent.(*RangeCond).From = newNode.(Expr) +} + +func replaceRangeCondLeft(newNode, parent SQLNode) { + parent.(*RangeCond).Left = newNode.(Expr) +} + +func replaceRangeCondTo(newNode, parent SQLNode) { + parent.(*RangeCond).To = newNode.(Expr) +} + +func replaceSelectComments(newNode, parent SQLNode) { + parent.(*Select).Comments = newNode.(Comments) +} + +func replaceSelectFrom(newNode, parent SQLNode) { + parent.(*Select).From = newNode.(TableExprs) +} + +func replaceSelectGroupBy(newNode, parent SQLNode) { + parent.(*Select).GroupBy = newNode.(GroupBy) +} + +func replaceSelectHaving(newNode, 
parent SQLNode) { + parent.(*Select).Having = newNode.(*Where) +} + +func replaceSelectLimit(newNode, parent SQLNode) { + parent.(*Select).Limit = newNode.(*Limit) +} + +func replaceSelectOrderBy(newNode, parent SQLNode) { + parent.(*Select).OrderBy = newNode.(OrderBy) +} + +func replaceSelectSelectExprs(newNode, parent SQLNode) { + parent.(*Select).SelectExprs = newNode.(SelectExprs) +} + +func replaceSelectWhere(newNode, parent SQLNode) { + parent.(*Select).Where = newNode.(*Where) +} + +type replaceSelectExprsItems int + +func (r *replaceSelectExprsItems) replace(newNode, container SQLNode) { + container.(SelectExprs)[int(*r)] = newNode.(SelectExpr) +} + +func (r *replaceSelectExprsItems) inc() { + *r++ +} + +func replaceSetComments(newNode, parent SQLNode) { + parent.(*Set).Comments = newNode.(Comments) +} + +func replaceSetExprs(newNode, parent SQLNode) { + parent.(*Set).Exprs = newNode.(SetExprs) +} + +func replaceSetExprExpr(newNode, parent SQLNode) { + parent.(*SetExpr).Expr = newNode.(Expr) +} + +func replaceSetExprName(newNode, parent SQLNode) { + parent.(*SetExpr).Name = newNode.(ColIdent) +} + +type replaceSetExprsItems int + +func (r *replaceSetExprsItems) replace(newNode, container SQLNode) { + container.(SetExprs)[int(*r)] = newNode.(*SetExpr) +} + +func (r *replaceSetExprsItems) inc() { + *r++ +} + +func replaceShowOnTable(newNode, parent SQLNode) { + parent.(*Show).OnTable = newNode.(TableName) +} + +func replaceShowTable(newNode, parent SQLNode) { + parent.(*Show).Table = newNode.(TableName) +} + +func replaceShowFilterFilter(newNode, parent SQLNode) { + parent.(*ShowFilter).Filter = newNode.(Expr) +} + +func replaceStarExprTableName(newNode, parent SQLNode) { + parent.(*StarExpr).TableName = newNode.(TableName) +} + +func replaceStreamComments(newNode, parent SQLNode) { + parent.(*Stream).Comments = newNode.(Comments) +} + +func replaceStreamSelectExpr(newNode, parent SQLNode) { + parent.(*Stream).SelectExpr = newNode.(SelectExpr) +} + +func 
replaceStreamTable(newNode, parent SQLNode) { + parent.(*Stream).Table = newNode.(TableName) +} + +func replaceSubquerySelect(newNode, parent SQLNode) { + parent.(*Subquery).Select = newNode.(SelectStatement) +} + +func replaceSubstrExprFrom(newNode, parent SQLNode) { + parent.(*SubstrExpr).From = newNode.(Expr) +} + +func replaceSubstrExprName(newNode, parent SQLNode) { + parent.(*SubstrExpr).Name = newNode.(*ColName) +} + +func replaceSubstrExprStrVal(newNode, parent SQLNode) { + parent.(*SubstrExpr).StrVal = newNode.(*SQLVal) +} + +func replaceSubstrExprTo(newNode, parent SQLNode) { + parent.(*SubstrExpr).To = newNode.(Expr) +} + +type replaceTableExprsItems int + +func (r *replaceTableExprsItems) replace(newNode, container SQLNode) { + container.(TableExprs)[int(*r)] = newNode.(TableExpr) +} + +func (r *replaceTableExprsItems) inc() { + *r++ +} + +func replaceTableNameName(newNode, parent SQLNode) { + tmp := parent.(TableName) + tmp.Name = newNode.(TableIdent) +} + +func replaceTableNameQualifier(newNode, parent SQLNode) { + tmp := parent.(TableName) + tmp.Qualifier = newNode.(TableIdent) +} + +type replaceTableNamesItems int + +func (r *replaceTableNamesItems) replace(newNode, container SQLNode) { + container.(TableNames)[int(*r)] = newNode.(TableName) +} + +func (r *replaceTableNamesItems) inc() { + *r++ +} + +type replaceTableSpecColumns int + +func (r *replaceTableSpecColumns) replace(newNode, container SQLNode) { + container.(*TableSpec).Columns[int(*r)] = newNode.(*ColumnDefinition) +} + +func (r *replaceTableSpecColumns) inc() { + *r++ +} + +type replaceTableSpecConstraints int + +func (r *replaceTableSpecConstraints) replace(newNode, container SQLNode) { + container.(*TableSpec).Constraints[int(*r)] = newNode.(*ConstraintDefinition) +} + +func (r *replaceTableSpecConstraints) inc() { + *r++ +} + +type replaceTableSpecIndexes int + +func (r *replaceTableSpecIndexes) replace(newNode, container SQLNode) { + container.(*TableSpec).Indexes[int(*r)] = 
newNode.(*IndexDefinition) +} + +func (r *replaceTableSpecIndexes) inc() { + *r++ +} + +func replaceTimestampFuncExprExpr1(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) +} + +func replaceTimestampFuncExprExpr2(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) +} + +func replaceUnaryExprExpr(newNode, parent SQLNode) { + parent.(*UnaryExpr).Expr = newNode.(Expr) +} + +func replaceUnionLeft(newNode, parent SQLNode) { + parent.(*Union).Left = newNode.(SelectStatement) +} + +func replaceUnionLimit(newNode, parent SQLNode) { + parent.(*Union).Limit = newNode.(*Limit) +} + +func replaceUnionOrderBy(newNode, parent SQLNode) { + parent.(*Union).OrderBy = newNode.(OrderBy) +} + +func replaceUnionRight(newNode, parent SQLNode) { + parent.(*Union).Right = newNode.(SelectStatement) +} + +func replaceUpdateComments(newNode, parent SQLNode) { + parent.(*Update).Comments = newNode.(Comments) +} + +func replaceUpdateExprs(newNode, parent SQLNode) { + parent.(*Update).Exprs = newNode.(UpdateExprs) +} + +func replaceUpdateLimit(newNode, parent SQLNode) { + parent.(*Update).Limit = newNode.(*Limit) +} + +func replaceUpdateOrderBy(newNode, parent SQLNode) { + parent.(*Update).OrderBy = newNode.(OrderBy) +} + +func replaceUpdateTableExprs(newNode, parent SQLNode) { + parent.(*Update).TableExprs = newNode.(TableExprs) +} + +func replaceUpdateWhere(newNode, parent SQLNode) { + parent.(*Update).Where = newNode.(*Where) +} + +func replaceUpdateExprExpr(newNode, parent SQLNode) { + parent.(*UpdateExpr).Expr = newNode.(Expr) +} + +func replaceUpdateExprName(newNode, parent SQLNode) { + parent.(*UpdateExpr).Name = newNode.(*ColName) +} + +type replaceUpdateExprsItems int + +func (r *replaceUpdateExprsItems) replace(newNode, container SQLNode) { + container.(UpdateExprs)[int(*r)] = newNode.(*UpdateExpr) +} + +func (r *replaceUpdateExprsItems) inc() { + *r++ +} + +func replaceUseDBName(newNode, parent SQLNode) { + 
parent.(*Use).DBName = newNode.(TableIdent) +} + +type replaceValTupleItems int + +func (r *replaceValTupleItems) replace(newNode, container SQLNode) { + container.(ValTuple)[int(*r)] = newNode.(Expr) +} + +func (r *replaceValTupleItems) inc() { + *r++ +} + +type replaceValuesItems int + +func (r *replaceValuesItems) replace(newNode, container SQLNode) { + container.(Values)[int(*r)] = newNode.(ValTuple) +} + +func (r *replaceValuesItems) inc() { + *r++ +} + +func replaceValuesFuncExprName(newNode, parent SQLNode) { + parent.(*ValuesFuncExpr).Name = newNode.(*ColName) +} + +func replaceVindexParamKey(newNode, parent SQLNode) { + tmp := parent.(VindexParam) + tmp.Key = newNode.(ColIdent) +} + +func replaceVindexSpecName(newNode, parent SQLNode) { + parent.(*VindexSpec).Name = newNode.(ColIdent) +} + +type replaceVindexSpecParams int + +func (r *replaceVindexSpecParams) replace(newNode, container SQLNode) { + container.(*VindexSpec).Params[int(*r)] = newNode.(VindexParam) +} + +func (r *replaceVindexSpecParams) inc() { + *r++ +} + +func replaceVindexSpecType(newNode, parent SQLNode) { + parent.(*VindexSpec).Type = newNode.(ColIdent) +} + +func replaceWhenCond(newNode, parent SQLNode) { + parent.(*When).Cond = newNode.(Expr) +} + +func replaceWhenVal(newNode, parent SQLNode) { + parent.(*When).Val = newNode.(Expr) +} + +func replaceWhereExpr(newNode, parent SQLNode) { + parent.(*Where).Expr = newNode.(Expr) +} + +// apply is where the visiting happens. 
Here is where we keep the big switch-case that will be used +// to do the actual visiting of SQLNodes +func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { + if node == nil || isNilValue(node) { + return + } + + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead + saved := a.cursor + a.cursor.replacer = replacer + a.cursor.node = node + a.cursor.parent = parent + + if a.pre != nil && !a.pre(&a.cursor) { + a.cursor = saved + return + } + + // walk children + // (the order of the cases is alphabetical) + switch n := node.(type) { + case nil: + case *AliasedExpr: + a.apply(node, n.As, replaceAliasedExprAs) + a.apply(node, n.Expr, replaceAliasedExprExpr) + + case *AliasedTableExpr: + a.apply(node, n.As, replaceAliasedTableExprAs) + a.apply(node, n.Expr, replaceAliasedTableExprExpr) + a.apply(node, n.Hints, replaceAliasedTableExprHints) + a.apply(node, n.Partitions, replaceAliasedTableExprPartitions) + + case *AndExpr: + a.apply(node, n.Left, replaceAndExprLeft) + a.apply(node, n.Right, replaceAndExprRight) + + case *AutoIncSpec: + a.apply(node, n.Column, replaceAutoIncSpecColumn) + a.apply(node, n.Sequence, replaceAutoIncSpecSequence) + + case *Begin: + + case *BinaryExpr: + a.apply(node, n.Left, replaceBinaryExprLeft) + a.apply(node, n.Right, replaceBinaryExprRight) + + case BoolVal: + + case *CaseExpr: + a.apply(node, n.Else, replaceCaseExprElse) + a.apply(node, n.Expr, replaceCaseExprExpr) + replacerWhens := replaceCaseExprWhens(0) + replacerWhensB := &replacerWhens + for _, item := range n.Whens { + a.apply(node, item, replacerWhensB.replace) + replacerWhensB.inc() + } + + case ColIdent: + + case *ColName: + a.apply(node, n.Name, replaceColNameName) + a.apply(node, n.Qualifier, replaceColNameQualifier) + + case *CollateExpr: + a.apply(node, n.Expr, replaceCollateExprExpr) + + case *ColumnDefinition: + a.apply(node, n.Name, replaceColumnDefinitionName) + + case *ColumnType: + a.apply(node, n.Autoincrement, 
replaceColumnTypeAutoincrement) + a.apply(node, n.Comment, replaceColumnTypeComment) + a.apply(node, n.Default, replaceColumnTypeDefault) + a.apply(node, n.Length, replaceColumnTypeLength) + a.apply(node, n.NotNull, replaceColumnTypeNotNull) + a.apply(node, n.OnUpdate, replaceColumnTypeOnUpdate) + a.apply(node, n.Scale, replaceColumnTypeScale) + a.apply(node, n.Unsigned, replaceColumnTypeUnsigned) + a.apply(node, n.Zerofill, replaceColumnTypeZerofill) + + case Columns: + replacer := replaceColumnsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case Comments: + + case *Commit: + + case *ComparisonExpr: + a.apply(node, n.Escape, replaceComparisonExprEscape) + a.apply(node, n.Left, replaceComparisonExprLeft) + a.apply(node, n.Right, replaceComparisonExprRight) + + case *ConstraintDefinition: + a.apply(node, n.Details, replaceConstraintDefinitionDetails) + + case *ConvertExpr: + a.apply(node, n.Expr, replaceConvertExprExpr) + a.apply(node, n.Type, replaceConvertExprType) + + case *ConvertType: + a.apply(node, n.Length, replaceConvertTypeLength) + a.apply(node, n.Scale, replaceConvertTypeScale) + + case *ConvertUsingExpr: + a.apply(node, n.Expr, replaceConvertUsingExprExpr) + + case *CurTimeFuncExpr: + a.apply(node, n.Fsp, replaceCurTimeFuncExprFsp) + a.apply(node, n.Name, replaceCurTimeFuncExprName) + + case *DBDDL: + + case *DDL: + a.apply(node, n.AutoIncSpec, replaceDDLAutoIncSpec) + a.apply(node, n.FromTables, replaceDDLFromTables) + a.apply(node, n.OptLike, replaceDDLOptLike) + a.apply(node, n.PartitionSpec, replaceDDLPartitionSpec) + a.apply(node, n.Table, replaceDDLTable) + a.apply(node, n.TableSpec, replaceDDLTableSpec) + a.apply(node, n.ToTables, replaceDDLToTables) + replacerVindexCols := replaceDDLVindexCols(0) + replacerVindexColsB := &replacerVindexCols + for _, item := range n.VindexCols { + a.apply(node, item, replacerVindexColsB.replace) + replacerVindexColsB.inc() + } 
+ a.apply(node, n.VindexSpec, replaceDDLVindexSpec) + + case *Default: + + case *Delete: + a.apply(node, n.Comments, replaceDeleteComments) + a.apply(node, n.Limit, replaceDeleteLimit) + a.apply(node, n.OrderBy, replaceDeleteOrderBy) + a.apply(node, n.Partitions, replaceDeletePartitions) + a.apply(node, n.TableExprs, replaceDeleteTableExprs) + a.apply(node, n.Targets, replaceDeleteTargets) + a.apply(node, n.Where, replaceDeleteWhere) + + case *ExistsExpr: + a.apply(node, n.Subquery, replaceExistsExprSubquery) + + case Exprs: + replacer := replaceExprsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *ForeignKeyDefinition: + a.apply(node, n.OnDelete, replaceForeignKeyDefinitionOnDelete) + a.apply(node, n.OnUpdate, replaceForeignKeyDefinitionOnUpdate) + a.apply(node, n.ReferencedColumns, replaceForeignKeyDefinitionReferencedColumns) + a.apply(node, n.ReferencedTable, replaceForeignKeyDefinitionReferencedTable) + a.apply(node, n.Source, replaceForeignKeyDefinitionSource) + + case *FuncExpr: + a.apply(node, n.Exprs, replaceFuncExprExprs) + a.apply(node, n.Name, replaceFuncExprName) + a.apply(node, n.Qualifier, replaceFuncExprQualifier) + + case GroupBy: + replacer := replaceGroupByItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *GroupConcatExpr: + a.apply(node, n.Exprs, replaceGroupConcatExprExprs) + a.apply(node, n.OrderBy, replaceGroupConcatExprOrderBy) + + case *IndexDefinition: + a.apply(node, n.Info, replaceIndexDefinitionInfo) + + case *IndexHints: + replacerIndexes := replaceIndexHintsIndexes(0) + replacerIndexesB := &replacerIndexes + for _, item := range n.Indexes { + a.apply(node, item, replacerIndexesB.replace) + replacerIndexesB.inc() + } + + case *IndexInfo: + a.apply(node, n.Name, replaceIndexInfoName) + + case *Insert: + a.apply(node, n.Columns, replaceInsertColumns) + 
a.apply(node, n.Comments, replaceInsertComments) + a.apply(node, n.OnDup, replaceInsertOnDup) + a.apply(node, n.Partitions, replaceInsertPartitions) + a.apply(node, n.Rows, replaceInsertRows) + a.apply(node, n.Table, replaceInsertTable) + + case *IntervalExpr: + a.apply(node, n.Expr, replaceIntervalExprExpr) + + case *IsExpr: + a.apply(node, n.Expr, replaceIsExprExpr) + + case JoinCondition: + a.apply(node, n.On, replaceJoinConditionOn) + a.apply(node, n.Using, replaceJoinConditionUsing) + + case *JoinTableExpr: + a.apply(node, n.Condition, replaceJoinTableExprCondition) + a.apply(node, n.LeftExpr, replaceJoinTableExprLeftExpr) + a.apply(node, n.RightExpr, replaceJoinTableExprRightExpr) + + case *Limit: + a.apply(node, n.Offset, replaceLimitOffset) + a.apply(node, n.Rowcount, replaceLimitRowcount) + + case ListArg: + + case *MatchExpr: + a.apply(node, n.Columns, replaceMatchExprColumns) + a.apply(node, n.Expr, replaceMatchExprExpr) + + case Nextval: + a.apply(node, n.Expr, replaceNextvalExpr) + + case *NotExpr: + a.apply(node, n.Expr, replaceNotExprExpr) + + case *NullVal: + + case OnDup: + replacer := replaceOnDupItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *OptLike: + a.apply(node, n.LikeTable, replaceOptLikeLikeTable) + + case *OrExpr: + a.apply(node, n.Left, replaceOrExprLeft) + a.apply(node, n.Right, replaceOrExprRight) + + case *Order: + a.apply(node, n.Expr, replaceOrderExpr) + + case OrderBy: + replacer := replaceOrderByItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *OtherAdmin: + + case *OtherRead: + + case *ParenExpr: + a.apply(node, n.Expr, replaceParenExprExpr) + + case *ParenSelect: + a.apply(node, n.Select, replaceParenSelectSelect) + + case *ParenTableExpr: + a.apply(node, n.Exprs, replaceParenTableExprExprs) + + case *PartitionDefinition: + a.apply(node, n.Limit, 
replacePartitionDefinitionLimit) + a.apply(node, n.Name, replacePartitionDefinitionName) + + case *PartitionSpec: + replacerDefinitions := replacePartitionSpecDefinitions(0) + replacerDefinitionsB := &replacerDefinitions + for _, item := range n.Definitions { + a.apply(node, item, replacerDefinitionsB.replace) + replacerDefinitionsB.inc() + } + a.apply(node, n.Name, replacePartitionSpecName) + + case Partitions: + replacer := replacePartitionsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *RangeCond: + a.apply(node, n.From, replaceRangeCondFrom) + a.apply(node, n.Left, replaceRangeCondLeft) + a.apply(node, n.To, replaceRangeCondTo) + + case ReferenceAction: + + case *Rollback: + + case *SQLVal: + + case *Select: + a.apply(node, n.Comments, replaceSelectComments) + a.apply(node, n.From, replaceSelectFrom) + a.apply(node, n.GroupBy, replaceSelectGroupBy) + a.apply(node, n.Having, replaceSelectHaving) + a.apply(node, n.Limit, replaceSelectLimit) + a.apply(node, n.OrderBy, replaceSelectOrderBy) + a.apply(node, n.SelectExprs, replaceSelectSelectExprs) + a.apply(node, n.Where, replaceSelectWhere) + + case SelectExprs: + replacer := replaceSelectExprsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *Set: + a.apply(node, n.Comments, replaceSetComments) + a.apply(node, n.Exprs, replaceSetExprs) + + case *SetExpr: + a.apply(node, n.Expr, replaceSetExprExpr) + a.apply(node, n.Name, replaceSetExprName) + + case SetExprs: + replacer := replaceSetExprsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *Show: + a.apply(node, n.OnTable, replaceShowOnTable) + a.apply(node, n.Table, replaceShowTable) + + case *ShowFilter: + a.apply(node, n.Filter, replaceShowFilterFilter) + + case *StarExpr: + a.apply(node, 
n.TableName, replaceStarExprTableName) + + case *Stream: + a.apply(node, n.Comments, replaceStreamComments) + a.apply(node, n.SelectExpr, replaceStreamSelectExpr) + a.apply(node, n.Table, replaceStreamTable) + + case *Subquery: + a.apply(node, n.Select, replaceSubquerySelect) + + case *SubstrExpr: + a.apply(node, n.From, replaceSubstrExprFrom) + a.apply(node, n.Name, replaceSubstrExprName) + a.apply(node, n.StrVal, replaceSubstrExprStrVal) + a.apply(node, n.To, replaceSubstrExprTo) + + case TableExprs: + replacer := replaceTableExprsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case TableIdent: + + case TableName: + a.apply(node, n.Name, replaceTableNameName) + a.apply(node, n.Qualifier, replaceTableNameQualifier) + + case TableNames: + replacer := replaceTableNamesItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *TableSpec: + replacerColumns := replaceTableSpecColumns(0) + replacerColumnsB := &replacerColumns + for _, item := range n.Columns { + a.apply(node, item, replacerColumnsB.replace) + replacerColumnsB.inc() + } + replacerConstraints := replaceTableSpecConstraints(0) + replacerConstraintsB := &replacerConstraints + for _, item := range n.Constraints { + a.apply(node, item, replacerConstraintsB.replace) + replacerConstraintsB.inc() + } + replacerIndexes := replaceTableSpecIndexes(0) + replacerIndexesB := &replacerIndexes + for _, item := range n.Indexes { + a.apply(node, item, replacerIndexesB.replace) + replacerIndexesB.inc() + } + + case *TimestampFuncExpr: + a.apply(node, n.Expr1, replaceTimestampFuncExprExpr1) + a.apply(node, n.Expr2, replaceTimestampFuncExprExpr2) + + case *UnaryExpr: + a.apply(node, n.Expr, replaceUnaryExprExpr) + + case *Union: + a.apply(node, n.Left, replaceUnionLeft) + a.apply(node, n.Limit, replaceUnionLimit) + a.apply(node, n.OrderBy, replaceUnionOrderBy) 
+ a.apply(node, n.Right, replaceUnionRight) + + case *Update: + a.apply(node, n.Comments, replaceUpdateComments) + a.apply(node, n.Exprs, replaceUpdateExprs) + a.apply(node, n.Limit, replaceUpdateLimit) + a.apply(node, n.OrderBy, replaceUpdateOrderBy) + a.apply(node, n.TableExprs, replaceUpdateTableExprs) + a.apply(node, n.Where, replaceUpdateWhere) + + case *UpdateExpr: + a.apply(node, n.Expr, replaceUpdateExprExpr) + a.apply(node, n.Name, replaceUpdateExprName) + + case UpdateExprs: + replacer := replaceUpdateExprsItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *Use: + a.apply(node, n.DBName, replaceUseDBName) + + case ValTuple: + replacer := replaceValTupleItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case Values: + replacer := replaceValuesItems(0) + replacerRef := &replacer + for _, item := range n { + a.apply(node, item, replacerRef.replace) + replacerRef.inc() + } + + case *ValuesFuncExpr: + a.apply(node, n.Name, replaceValuesFuncExprName) + + case VindexParam: + a.apply(node, n.Key, replaceVindexParamKey) + + case *VindexSpec: + a.apply(node, n.Name, replaceVindexSpecName) + replacerParams := replaceVindexSpecParams(0) + replacerParamsB := &replacerParams + for _, item := range n.Params { + a.apply(node, item, replacerParamsB.replace) + replacerParamsB.inc() + } + a.apply(node, n.Type, replaceVindexSpecType) + + case *When: + a.apply(node, n.Cond, replaceWhenCond) + a.apply(node, n.Val, replaceWhenVal) + + case *Where: + a.apply(node, n.Expr, replaceWhereExpr) + + default: + panic("unknown ast type " + reflect.TypeOf(node).String()) + } + + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + + a.cursor = saved +} + +func isNilValue(i interface{}) bool { + valueOf := reflect.ValueOf(i) + kind := valueOf.Kind() + isNullable := kind == reflect.Ptr || kind == reflect.Array || 
kind == reflect.Slice + return isNullable && valueOf.IsNil() +} diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go new file mode 100644 index 00000000000..c5732d1aceb --- /dev/null +++ b/go/vt/sqlparser/rewriter_api.go @@ -0,0 +1,91 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +// The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go + +// Rewrite traverses a syntax tree recursively, starting with root, +// and calling pre and post for each node as described below. +// Rewrite returns the syntax tree, possibly modified. +// +// If pre is not nil, it is called for each node before the node's +// children are traversed (pre-order). If pre returns false, no +// children are traversed, and post is not called for that node. +// +// If post is not nil, and a prior call of pre didn't return false, +// post is called for each node after its children are traversed +// (post-order). If post returns false, traversal is terminated and +// Rewrite returns immediately. +// +// Only fields that refer to AST nodes are considered children; +// i.e., fields of basic types (strings, []byte, etc.) are ignored. 
+// +func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { + parent := &struct{ SQLNode }{node} + defer func() { + if r := recover(); r != nil && r != abort { + panic(r) + } + result = parent.SQLNode + }() + + a := &application{ + pre: pre, + post: post, + cursor: Cursor{}, + } + + // this is the root-replacer, used when the user replaces the root of the ast + replacer := func(newNode SQLNode, _ SQLNode) { + parent.SQLNode = newNode + } + + a.apply(parent, node, replacer) + + return parent.SQLNode +} + +// An ApplyFunc is invoked by Rewrite for each node n, even if n is nil, +// before and/or after the node's children, using a Cursor describing +// the current node and providing operations on it. +// +// The return value of ApplyFunc controls the syntax tree traversal. +// See Rewrite for details. +type ApplyFunc func(*Cursor) bool + +var abort = new(int) // singleton, to signal termination of Rewrite + +// A Cursor describes a node encountered during Rewrite. +// Information about the node and its parent is available +// from the Node and Parent methods. +type Cursor struct { + parent SQLNode + replacer replacerFunc + node SQLNode +} + +// Node returns the current Node. +func (c *Cursor) Node() SQLNode { return c.node } + +// Parent returns the parent of the current Node. +func (c *Cursor) Parent() SQLNode { return c.parent } + +// Replace replaces the current node in the parent field with this new object. The user needs to make sure to not +replace the object with something of the wrong type, or the visitor will panic. +func (c *Cursor) Replace(newNode SQLNode) { + c.replacer(newNode, c.parent) +} diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 69d57b00320..a62bac5942f 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -37,10 +37,6 @@ const ( // This is used for sending different IN clause values // to different shards. 
ListVarName = "__vals" - //LastInsertIDName is a reserved bind var name for last_insert_id() - LastInsertIDName = "__lastInsertId" - //DBVarName is a reserved bind var name for database() - DBVarName = "__vtdbname" ) // VCursor defines the interface the engine will use diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index c3aff0561df..2c3ae248764 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -274,9 +274,24 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql if err != nil { return nil, err } - sqlparser.Normalize(stmt, bindVars, "vtg") - normalized := sqlparser.String(stmt) + rewriteResult, err := sqlparser.PrepareAST(stmt, bindVars, "vtg") + if err != nil { + return nil, err + } + normalized := sqlparser.String(rewriteResult.AST) sql = comments.Leading + normalized + comments.Trailing + if rewriteResult.NeedDatabase { + keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString) + log.Warningf("This is the keyspace name: ---> %v", keyspace) + if keyspace == "" { + bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable + } else { + bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace) + } + } + if rewriteResult.NeedLastInsertID { + bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(safeSession.GetLastInsertId()) + } } logStats.PlanTime = execStart.Sub(logStats.StartTime) logStats.SQL = sql @@ -307,14 +322,14 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql } if plan.NeedsLastInsertID { - bindVars[engine.LastInsertIDName] = sqltypes.Uint64BindVariable(safeSession.GetLastInsertId()) + bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(safeSession.GetLastInsertId()) } if plan.NeedsDatabaseName { keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString) if keyspace == "" { - bindVars[engine.DBVarName] = sqltypes.NullBindVariable + bindVars[sqlparser.DBVarName] = 
sqltypes.NullBindVariable } else { - bindVars[engine.DBVarName] = sqltypes.StringBindVariable(keyspace) + bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace) } } @@ -1399,7 +1414,7 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. return nil, err } if !e.normalize { - plan, err := planbuilder.BuildFromStmt(sql, stmt, vcursor) + plan, err := planbuilder.BuildFromStmt(sql, stmt, vcursor, false, false) if err != nil { return nil, err } @@ -1410,8 +1425,12 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. } // Normalize and retry. - sqlparser.Normalize(stmt, bindVars, "vtg") - normalized := sqlparser.String(stmt) + result, err := sqlparser.PrepareAST(stmt, bindVars, "vtg") + if err != nil { + return nil, vterrors.Wrap(err, "failed to rewrite ast before planning") + } + rewrittenStatement := result.AST + normalized := sqlparser.String(rewrittenStatement) if logStats != nil { logStats.SQL = comments.Leading + normalized + comments.Trailing @@ -1422,11 +1441,11 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. 
if result, ok := e.plans.Get(planKey); ok { return result.(*engine.Plan), nil } - plan, err := planbuilder.BuildFromStmt(normalized, stmt, vcursor) + plan, err := planbuilder.BuildFromStmt(normalized, rewrittenStatement, vcursor, result.NeedLastInsertID, result.NeedDatabase) if err != nil { return nil, err } - if !skipQueryPlanCache && !sqlparser.SkipQueryPlanCacheDirective(stmt) { + if !skipQueryPlanCache && !sqlparser.SkipQueryPlanCacheDirective(rewrittenStatement) { e.plans.Set(planKey, plan) } return plan, nil diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index a4b68b44172..e5bba560ce5 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1863,14 +1863,17 @@ func TestDeleteEqualWithPrepare(t *testing.T) { func TestUpdateLastInsertID(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() + executor.normalize = true sql := "update user set a = last_insert_id() where id = 1" masterSession.LastInsertId = 43 _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "update user set a = :__lastInsertId where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", - BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(43)}, + Sql: "update user set a = :__lastInsertId where id = :vtg1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", + BindVariables: map[string]*querypb.BindVariable{ + "__lastInsertId": sqltypes.Uint64BindVariable(43), + "vtg1": sqltypes.Int64BindVariable(1)}, }} require.Equal(t, wantQueries, sbc1.Queries) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index cccd1effab6..a94f0b5c6d5 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -237,6 +237,7 @@ func TestStreamBuffering(t *testing.T) { func TestSelectLastInsertId(t *testing.T) { executor, sbc1, _, _ := 
createExecutorEnv() + executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -255,14 +256,14 @@ func TestSelectLastInsertId(t *testing.T) { func TestSelectLastInsertIdInUnion(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() - - sql := "select last_insert_id() as id union select id from user where 1 != 1" + executor.normalize = true + sql := "select last_insert_id() as id union select id from user" _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) if err != nil { t.Error(err) } wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__lastInsertId as id from dual union select id from user where 1 != 1", + Sql: "select :__lastInsertId as id from dual union select id from user", BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(0)}, }} @@ -271,6 +272,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { func TestSelectLastInsertIdInWhere(t *testing.T) { executor, _, _, lookup := createExecutorEnv() + executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -289,6 +291,7 @@ func TestSelectLastInsertIdInWhere(t *testing.T) { func TestLastInsertIDInVirtualTable(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() + executor.normalize = true result1 := []*sqltypes.Result{{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -316,6 +319,7 @@ func TestLastInsertIDInVirtualTable(t *testing.T) { func TestLastInsertIDInSubQueryExpression(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() + executor.normalize = true result1 := []*sqltypes.Result{{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -343,7 +347,7 @@ func TestLastInsertIDInSubQueryExpression(t *testing.T) { func TestSelectDatabase(t *testing.T) { executor, sbc1, _, _ := createExecutorEnv() - + executor.normalize = true sql := "select database()" newSession := *masterSession session := 
NewSafeSession(&newSession) diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 0e2b9a70263..dc3770b1e05 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -193,6 +193,26 @@ func TestExecutorTransactionsNoAutoCommit(t *testing.T) { } } +func TestDirectTargetRewrites(t *testing.T) { + executor, _, _, sbclookup := createExecutorEnv() + executor.normalize = true + + session := &vtgatepb.Session{ + TargetString: "TestUnsharded/0@master", + Autocommit: true, + TransactionMode: vtgatepb.TransactionMode_MULTI, + } + sql := "select database()" + + if _, err := executor.Execute(context.Background(), "TestExecute", NewSafeSession(session), sql, map[string]*querypb.BindVariable{}); err != nil { + t.Error(err) + } + testQueries(t, "sbclookup", sbclookup, []*querypb.BoundQuery{{ + Sql: "select :__vtdbname as `database()` from dual", + BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded")}, + }}) +} + func TestExecutorTransactionsAutoCommit(t *testing.T) { executor, _, _, sbclookup := createExecutorEnv() session := NewSafeSession(&vtgatepb.Session{TargetString: "@master", Autocommit: true}) diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 0db9dc94b96..dc912cd72e6 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -252,35 +252,38 @@ func (rsb *resultsBuilder) SupplyWeightString(colNumber int) (weightcolNumber in //------------------------------------------------------------------------- // Build builds a plan for a query based on the specified vschema. -// It's the main entry point for this package. 
+// This method is only used from tests func Build(query string, vschema ContextVSchema) (*engine.Plan, error) { stmt, err := sqlparser.Parse(query) if err != nil { return nil, err } - return BuildFromStmt(query, stmt, vschema) + result, err := sqlparser.RewriteAST(stmt) + if err != nil { + return nil, err + } + + return BuildFromStmt(query, result.AST, vschema, result.NeedLastInsertID, result.NeedDatabase) } // BuildFromStmt builds a plan based on the AST provided. // TODO(sougou): The query input is trusted as the source // of the AST. Maybe this function just returns instructions // and engine.Plan can be built by the caller. -func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema) (*engine.Plan, error) { +func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema, needsLastInsertID, needsDBName bool) (*engine.Plan, error) { var err error var instruction engine.Primitive - var needsLastInsertID bool - var needsDBName bool switch stmt := stmt.(type) { case *sqlparser.Select: - instruction, needsLastInsertID, needsDBName, err = buildSelectPlan(stmt, vschema) + instruction, err = buildSelectPlan(stmt, vschema) case *sqlparser.Insert: - instruction, needsLastInsertID, needsDBName, err = buildInsertPlan(stmt, vschema) + instruction, err = buildInsertPlan(stmt, vschema) case *sqlparser.Update: - instruction, needsLastInsertID, needsDBName, err = buildUpdatePlan(stmt, vschema) + instruction, err = buildUpdatePlan(stmt, vschema) case *sqlparser.Delete: - instruction, needsLastInsertID, needsDBName, err = buildDeletePlan(stmt, vschema) + instruction, err = buildDeletePlan(stmt, vschema) case *sqlparser.Union: - instruction, needsLastInsertID, needsDBName, err = buildUnionPlan(stmt, vschema) + instruction, err = buildUnionPlan(stmt, vschema) case *sqlparser.Set: return nil, errors.New("unsupported construct: set") case *sqlparser.Show: diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go 
index 0a76153e349..5dc44b3d6d6 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -29,31 +29,31 @@ import ( ) // buildDeletePlan builds the instructions for a DELETE statement. -func buildDeletePlan(del *sqlparser.Delete, vschema ContextVSchema) (_ *engine.Delete, needsLastInsertID bool, needDbName bool, _ error) { +func buildDeletePlan(del *sqlparser.Delete, vschema ContextVSchema) (*engine.Delete, error) { edel := &engine.Delete{} pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(del))) ro, err := pb.processDMLTable(del.TableExprs) if err != nil { - return nil, false, false, err + return nil, err } edel.Query = generateQuery(del) edel.Keyspace = ro.eroute.Keyspace if !edel.Keyspace.Sharded { // We only validate non-table subexpressions because the previous analysis has already validated them. if !pb.finalizeUnshardedDMLSubqueries(del.Targets, del.Where, del.OrderBy, del.Limit) { - return nil, false, false, errors.New("unsupported: sharded subqueries in DML") + return nil, errors.New("unsupported: sharded subqueries in DML") } edel.Opcode = engine.DeleteUnsharded // Generate query after all the analysis. Otherwise table name substitutions for // routed tables won't happen. edel.Query = generateQuery(del) - return edel, pb.needsLastInsertID, pb.needsDbName, nil + return edel, nil } if del.Targets != nil || ro.vschemaTable == nil { - return nil, false, false, errors.New("unsupported: multi-table delete statement in sharded keyspace") + return nil, errors.New("unsupported: multi-table delete statement in sharded keyspace") } if hasSubquery(del) { - return nil, false, false, errors.New("unsupported: subqueries in sharded DML") + return nil, errors.New("unsupported: subqueries in sharded DML") } edel.Table = ro.vschemaTable // Generate query after all the analysis. 
Otherwise table name substitutions for @@ -68,11 +68,11 @@ func buildDeletePlan(del *sqlparser.Delete, vschema ContextVSchema) (_ *engine.D edel.QueryTimeout = queryTimeout(directives) if ro.eroute.TargetDestination != nil { if ro.eroute.TargetTabletType != topodatapb.TabletType_MASTER { - return nil, false, false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported: DELETE statement with a replica target") + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported: DELETE statement with a replica target") } edel.Opcode = engine.DeleteByDestination edel.TargetDestination = ro.eroute.TargetDestination - return edel, pb.needsLastInsertID, pb.needsDbName, nil + return edel, nil } edel.Vindex, edel.Values, err = getDMLRouting(del.Where, edel.Table) // We couldn't generate a route for a single shard @@ -85,15 +85,15 @@ func buildDeletePlan(del *sqlparser.Delete, vschema ContextVSchema) (_ *engine.D if edel.Opcode == engine.DeleteScatter { if len(edel.Table.Owned) != 0 { - return nil, false, false, errors.New("unsupported: multi shard delete on a table with owned lookup vindexes") + return nil, errors.New("unsupported: multi shard delete on a table with owned lookup vindexes") } if del.Limit != nil { - return nil, false, false, errors.New("unsupported: multi shard delete with limit") + return nil, errors.New("unsupported: multi shard delete with limit") } } edel.OwnedVindexQuery = generateDeleteSubquery(del, edel.Table) - return edel, pb.needsLastInsertID, pb.needsDbName, nil + return edel, nil } // generateDeleteSubquery generates the query to fetch the rows diff --git a/go/vt/vtgate/planbuilder/expr.go b/go/vt/vtgate/planbuilder/expr.go index ad3e8b76efa..ad06a4f7219 100644 --- a/go/vt/vtgate/planbuilder/expr.go +++ b/go/vt/vtgate/planbuilder/expr.go @@ -149,7 +149,6 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout } } subqueries = append(subqueries, sqi) - pb.copyBindVarNeeds(spb) return false, nil } return 
true, nil @@ -313,12 +312,12 @@ func valEqual(a, b sqlparser.Expr) bool { switch a.Type { case sqlparser.ValArg: if b.Type == sqlparser.ValArg { - return bytes.Equal([]byte(a.Val), []byte(b.Val)) + return bytes.Equal(a.Val, b.Val) } case sqlparser.StrVal: switch b.Type { case sqlparser.StrVal: - return bytes.Equal([]byte(a.Val), []byte(b.Val)) + return bytes.Equal(a.Val, b.Val) case sqlparser.HexVal: return hexEqual(b, a) } @@ -326,7 +325,7 @@ func valEqual(a, b sqlparser.Expr) bool { return hexEqual(a, b) case sqlparser.IntVal: if b.Type == (sqlparser.IntVal) { - return bytes.Equal([]byte(a.Val), []byte(b.Val)) + return bytes.Equal(a.Val, b.Val) } } } diff --git a/go/vt/vtgate/planbuilder/expression_rewriting.go b/go/vt/vtgate/planbuilder/expression_rewriting.go deleted file mode 100644 index 7e3a3dd0425..00000000000 --- a/go/vt/vtgate/planbuilder/expression_rewriting.go +++ /dev/null @@ -1,96 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package planbuilder - -import ( - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/engine" -) - -// RewriteResult contains the rewritten expression and meta information about it -type RewriteResult struct { - Expression sqlparser.Expr - NeedLastInsertID bool - NeedDatabase bool -} - -// UpdateBindVarNeeds copies bind var needs from primitiveBuilders used for subqueries -func (rr *RewriteResult) UpdateBindVarNeeds(pb *primitiveBuilder) { - pb.needsDbName = pb.needsDbName || rr.NeedDatabase - pb.needsLastInsertID = pb.needsLastInsertID || rr.NeedLastInsertID -} - -// RewriteAndUpdateBuilder rewrites expressions and updates the primitive builder to remember what bindvar needs it has -func RewriteAndUpdateBuilder(in sqlparser.Expr, pb *primitiveBuilder) (sqlparser.Expr, error) { - out, err := Rewrite(in) - if err != nil { - return nil, err - } - out.UpdateBindVarNeeds(pb) - return out.Expression, nil -} - -// Rewrite will rewrite an expression. Currently it does the following rewrites: -// - `last_insert_id()` => `:__lastInsertId` -// - `database()` => `:__vtdbname` -func Rewrite(in sqlparser.Expr) (*RewriteResult, error) { - rewrites := make(map[*sqlparser.FuncExpr]sqlparser.Expr) - liid := false - db := false - - err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node := node.(type) { - case *sqlparser.FuncExpr: - switch { - case node.Name.EqualString("last_insert_id"): - if len(node.Exprs) > 0 { - return false, vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported") - } - rewrites[node] = bindVarExpression(engine.LastInsertIDName) - liid = true - case node.Name.EqualString("database"): - if len(node.Exprs) > 0 { - return false, vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. 
DATABASE() takes no arguments") - } - rewrites[node] = bindVarExpression(engine.DBVarName) - db = true - } - return true, nil - } - return true, nil - }, in) - - if err != nil { - return nil, err - } - - for from, to := range rewrites { - in = sqlparser.ReplaceExpr(in, from, to) - } - - return &RewriteResult{ - Expression: in, - NeedLastInsertID: liid, - NeedDatabase: db, - }, nil -} - -func bindVarExpression(name string) *sqlparser.SQLVal { - return sqlparser.NewValArg([]byte(":" + name)) -} diff --git a/go/vt/vtgate/planbuilder/expression_rewriting_test.go b/go/vt/vtgate/planbuilder/expression_rewriting_test.go deleted file mode 100644 index cb897aa00b3..00000000000 --- a/go/vt/vtgate/planbuilder/expression_rewriting_test.go +++ /dev/null @@ -1,133 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package planbuilder - -import ( - "testing" - - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtgate/engine" - - "github.com/stretchr/testify/assert" -) - -func TestDoesNotRewrite(t *testing.T) { - // SELECT 1 - literal := newIntVal("1") - result, err := Rewrite(literal) - assert.NoError(t, err) - assert.Equal(t, literal, result.Expression) - assert.False(t, result.NeedLastInsertID, "should not need last insert id") -} - -func liidBindVar() sqlparser.Expr { - return sqlparser.NewValArg([]byte(":" + engine.LastInsertIDName)) -} - -func TestRewriteLIID(t *testing.T) { - // SELECT last_insert_id() - lastInsertID := &sqlparser.FuncExpr{ - Qualifier: sqlparser.TableIdent{}, - Name: sqlparser.NewColIdent("last_insert_id"), - Distinct: false, - Exprs: nil, - } - result, err := Rewrite(lastInsertID) - assert.NoError(t, err) - assert.Equal(t, liidBindVar(), result.Expression) - assert.True(t, result.NeedLastInsertID, "should need last insert id") -} - -func TestRewriteLIIDComplex(t *testing.T) { - // SELECT 1 + last_insert_id() - lastInsertID := &sqlparser.FuncExpr{ - Qualifier: sqlparser.TableIdent{}, - Name: sqlparser.NewColIdent("last_insert_id"), - Distinct: false, - Exprs: nil, - } - expr := &sqlparser.BinaryExpr{ - Operator: "+", - Left: sqlparser.NewIntVal([]byte("1")), - Right: lastInsertID, - } - - result, err := Rewrite(expr) - - assert.NoError(t, err) - - expected := &sqlparser.BinaryExpr{ - Operator: "+", - Left: sqlparser.NewIntVal([]byte("1")), - Right: liidBindVar(), - } - - assert.Equal(t, expected, result.Expression) - assert.True(t, result.NeedLastInsertID, "should need last insert id") -} - -func TestRewriteLIIDComplex2(t *testing.T) { - // SELECT last_insert_id() + last_insert_id() - lastInsertID1 := &sqlparser.FuncExpr{ - Qualifier: sqlparser.TableIdent{}, - Name: sqlparser.NewColIdent("last_insert_id"), - Distinct: false, - Exprs: nil, - } - lastInsertID2 := &sqlparser.FuncExpr{ - Qualifier: sqlparser.TableIdent{}, - Name: 
sqlparser.NewColIdent("last_insert_id"), - Distinct: false, - Exprs: nil, - } - expr := &sqlparser.BinaryExpr{ - Operator: "+", - Left: lastInsertID1, - Right: lastInsertID2, - } - - result, err := Rewrite(expr) - - assert.NoError(t, err) - - expected := &sqlparser.BinaryExpr{ - Operator: "+", - Left: liidBindVar(), - Right: liidBindVar(), - } - - assert.Equal(t, expected, result.Expression) - assert.True(t, result.NeedLastInsertID, "should need last insert id") -} - -func TestRewriteDatabaseFunc(t *testing.T) { - // SELECT database() - database := &sqlparser.FuncExpr{ - Qualifier: sqlparser.TableIdent{}, - Name: sqlparser.NewColIdent("database"), - Distinct: false, - Exprs: nil, - } - result, err := Rewrite(database) - assert.NoError(t, err) - assert.Equal(t, databaseBindVar(), result.Expression) - assert.True(t, result.NeedDatabase, "should need database name") -} - -func databaseBindVar() sqlparser.Expr { - return sqlparser.NewValArg([]byte(":" + engine.DBVarName)) -} diff --git a/go/vt/vtgate/planbuilder/from.go b/go/vt/vtgate/planbuilder/from.go index 4e06fae7b78..772d224f13a 100644 --- a/go/vt/vtgate/planbuilder/from.go +++ b/go/vt/vtgate/planbuilder/from.go @@ -123,7 +123,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl // build a route primitive that has the subquery in its // FROM clause. This allows for other constructs to be // later pushed into it. - rb, st := newRoute(&sqlparser.Select{From: sqlparser.TableExprs([]sqlparser.TableExpr{tableExpr})}) + rb, st := newRoute(&sqlparser.Select{From: []sqlparser.TableExpr{tableExpr}}) // The subquery needs to be represented as a new logical table in the symtab. // The new route will inherit the routeOptions of the underlying subquery. 
@@ -166,7 +166,6 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl rb.routeOptions = subroute.routeOptions subroute.Redirect = rb pb.bldr, pb.st = rb, st - pb.copyBindVarNeeds(spb) return nil } return fmt.Errorf("BUG: unexpected table expression type: %T", tableExpr.Expr) diff --git a/go/vt/vtgate/planbuilder/insert.go b/go/vt/vtgate/planbuilder/insert.go index f96ad4f81ac..cbfc1890927 100644 --- a/go/vt/vtgate/planbuilder/insert.go +++ b/go/vt/vtgate/planbuilder/insert.go @@ -29,47 +29,31 @@ import ( ) // buildInsertPlan builds the route for an INSERT statement. -func buildInsertPlan(ins *sqlparser.Insert, vschema ContextVSchema) (_ engine.Primitive, needsLastInsertID bool, needsDBName bool, _ error) { +func buildInsertPlan(ins *sqlparser.Insert, vschema ContextVSchema) (engine.Primitive, error) { pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(ins))) exprs := sqlparser.TableExprs{&sqlparser.AliasedTableExpr{Expr: ins.Table}} ro, err := pb.processDMLTable(exprs) if err != nil { - return nil, false, false, err + return nil, err } // The table might have been routed to a different one. 
ins.Table = exprs[0].(*sqlparser.AliasedTableExpr).Expr.(sqlparser.TableName) if ro.eroute.TargetDestination != nil { - return nil, false, false, errors.New("unsupported: INSERT with a target destination") + return nil, errors.New("unsupported: INSERT with a target destination") } if !ro.vschemaTable.Keyspace.Sharded { if !pb.finalizeUnshardedDMLSubqueries(ins) { - return nil, false, false, errors.New("unsupported: sharded subquery in insert values") + return nil, errors.New("unsupported: sharded subquery in insert values") } return buildInsertUnshardedPlan(ins, ro.vschemaTable) } if ins.Action == sqlparser.ReplaceStr { - return nil, false, false, errors.New("unsupported: REPLACE INTO with sharded schema") + return nil, errors.New("unsupported: REPLACE INTO with sharded schema") } return buildInsertShardedPlan(ins, ro.vschemaTable) } -// rewriteValues will go over the insert values and rewrite them when needed -func rewriteValues(in sqlparser.Values) (_ sqlparser.Values, needsLastInsertID bool, needsDBName bool, _ error) { - for i, row := range in { - for j, val := range row { - rewritten, err := Rewrite(val) - if err != nil { - return nil, false, false, vterrors.Wrap(err, "failed to rewrite insert value") - } - in[i][j] = rewritten.Expression - needsLastInsertID = needsLastInsertID || rewritten.NeedLastInsertID - needsDBName = needsDBName || rewritten.NeedDatabase - } - } - return in, needsLastInsertID, needsDBName, nil -} - -func buildInsertUnshardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ engine.Primitive, needsLastInsertID bool, needsDBName bool, _ error) { +func buildInsertUnshardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (engine.Primitive, error) { eins := engine.NewSimpleInsert( engine.InsertUnsharded, table, @@ -79,18 +63,14 @@ func buildInsertUnshardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ e switch insertValues := ins.Rows.(type) { case *sqlparser.Select, *sqlparser.Union: if eins.Table.AutoIncrement != nil { - return 
nil, false, false, errors.New("unsupported: auto-inc and select in insert") + return nil, errors.New("unsupported: auto-inc and select in insert") } eins.Query = generateQuery(ins) - return eins, false, false, nil + return eins, nil case sqlparser.Values: - var err error - rows, needsLastInsertID, needsDBName, err = rewriteValues(insertValues) - if err != nil { - return nil, false, false, err - } + rows = insertValues default: - return nil, false, false, fmt.Errorf("BUG: unexpected construct in insert: %T", insertValues) + return nil, fmt.Errorf("BUG: unexpected construct in insert: %T", insertValues) } if eins.Table.AutoIncrement == nil { eins.Query = generateQuery(ins) @@ -100,24 +80,24 @@ func buildInsertUnshardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ e if table.ColumnListAuthoritative { populateInsertColumnlist(ins, table) } else { - return nil, false, false, errors.New("column list required for tables with auto-inc columns") + return nil, errors.New("column list required for tables with auto-inc columns") } } for _, row := range rows { if len(ins.Columns) != len(row) { - return nil, false, false, errors.New("column list doesn't match values") + return nil, errors.New("column list doesn't match values") } } if err := modifyForAutoinc(ins, eins); err != nil { - return nil, false, false, err + return nil, err } eins.Query = generateQuery(ins) } - return eins, needsLastInsertID, needsDBName, nil + return eins, nil } -func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ engine.Primitive, needsLastInsertID bool, needsDBName bool, _ error) { +func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (engine.Primitive, error) { eins := engine.NewSimpleInsert( engine.InsertSharded, table, @@ -128,7 +108,7 @@ func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ eng } if ins.OnDup != nil { if isVindexChanging(sqlparser.UpdateExprs(ins.OnDup), eins.Table.ColumnVindexes) { - return nil, false, 
false, errors.New("unsupported: DML cannot change vindex column") + return nil, errors.New("unsupported: DML cannot change vindex column") } eins.Opcode = engine.InsertShardedIgnore } @@ -136,7 +116,7 @@ func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ eng if table.ColumnListAuthoritative { populateInsertColumnlist(ins, table) } else { - return nil, false, false, errors.New("no column list") + return nil, errors.New("no column list") } } @@ -150,28 +130,24 @@ func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ eng var rows sqlparser.Values switch insertValues := ins.Rows.(type) { case *sqlparser.Select, *sqlparser.Union: - return nil, false, false, errors.New("unsupported: insert into select") + return nil, errors.New("unsupported: insert into select") case sqlparser.Values: - var err error - rows, needsLastInsertID, needsDBName, err = rewriteValues(insertValues) - if err != nil { - return nil, false, false, err - } + rows = insertValues if hasSubquery(rows) { - return nil, false, false, errors.New("unsupported: subquery in insert values") + return nil, errors.New("unsupported: subquery in insert values") } default: - return nil, false, false, fmt.Errorf("BUG: unexpected construct in insert: %T", insertValues) + return nil, fmt.Errorf("BUG: unexpected construct in insert: %T", insertValues) } for _, value := range rows { if len(ins.Columns) != len(value) { - return nil, false, false, errors.New("column list doesn't match values") + return nil, errors.New("column list doesn't match values") } } if eins.Table.AutoIncrement != nil { if err := modifyForAutoinc(ins, eins); err != nil { - return nil, false, false, err + return nil, err } } @@ -185,7 +161,7 @@ func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ eng for rowNum, row := range rows { innerpv, err := sqlparser.NewPlanValue(row[colNum]) if err != nil { - return nil, false, false, vterrors.Wrapf(err, "could not compute value for vindex 
or auto-inc column") + return nil, vterrors.Wrapf(err, "could not compute value for vindex or auto-inc column") } routeValues[vIdx].Values[colIdx].Values[rowNum] = innerpv } @@ -204,7 +180,7 @@ func buildInsertShardedPlan(ins *sqlparser.Insert, table *vindexes.Table) (_ eng eins.VindexValues = routeValues eins.Query = generateQuery(ins) generateInsertShardedQuery(ins, eins, rows) - return eins, needsLastInsertID, needsDBName, nil + return eins, nil } func populateInsertColumnlist(ins *sqlparser.Insert, table *vindexes.Table) { diff --git a/go/vt/vtgate/planbuilder/primitive_builder.go b/go/vt/vtgate/planbuilder/primitive_builder.go index c2b49f09da5..bed53cf8c0a 100644 --- a/go/vt/vtgate/planbuilder/primitive_builder.go +++ b/go/vt/vtgate/planbuilder/primitive_builder.go @@ -21,12 +21,10 @@ package planbuilder // the jointab. It can create transient planBuilders due // to the recursive nature of SQL. type primitiveBuilder struct { - vschema ContextVSchema - jt *jointab - bldr builder - st *symtab - needsLastInsertID bool - needsDbName bool + vschema ContextVSchema + jt *jointab + bldr builder + st *symtab } func newPrimitiveBuilder(vschema ContextVSchema, jt *jointab) *primitiveBuilder { @@ -35,8 +33,3 @@ func newPrimitiveBuilder(vschema ContextVSchema, jt *jointab) *primitiveBuilder jt: jt, } } - -func (pb *primitiveBuilder) copyBindVarNeeds(subQ *primitiveBuilder) { - pb.needsLastInsertID = pb.needsLastInsertID || subQ.needsLastInsertID - pb.needsDbName = pb.needsDbName || subQ.needsDbName -} diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 82b7db25273..f8746c7e3d9 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,22 +20,20 @@ import ( "errors" "fmt" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) // buildSelectPlan is the new function to build a Select plan. 
-func buildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (_ engine.Primitive, needsLastInsertID bool, needsDBName bool, _ error) { +func buildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.Primitive, error) { pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) if err := pb.processSelect(sel, nil); err != nil { - return nil, false, false, err + return nil, err } if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { - return nil, false, false, err + return nil, err } - return pb.bldr.Primitive(), pb.needsLastInsertID, pb.needsDbName, nil + return pb.bldr.Primitive(), nil } // processSelect builds a primitive tree for the given query or subquery. @@ -126,11 +124,7 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) // pushes it down, and updates the route info if the new constraint improves // the primitive. This function can push to a WHERE or HAVING clause. func (pb *primitiveBuilder) pushFilter(in sqlparser.Expr, whereType string) error { - rewritten, err := RewriteAndUpdateBuilder(in, pb) - if err != nil { - return vterrors.Wrap(err, "failed to Rewrite expressions") - } - filters := splitAndExpression(nil, rewritten) + filters := splitAndExpression(nil, in) reorderBySubquery(filters) for _, filter := range filters { pullouts, origin, expr, err := pb.findOrigin(filter) @@ -195,16 +189,7 @@ func (pb *primitiveBuilder) pushSelectRoutes(selectExprs sqlparser.SelectExprs) for _, node := range selectExprs { switch node := node.(type) { case *sqlparser.AliasedExpr: - rewritten, err := RewriteAndUpdateBuilder(node.Expr, pb) - if err != nil { - return nil, vterrors.Wrap(err, "failed to Rewrite expression") - } - if rewritten != node.Expr && node.As.IsEmpty() { - buf := sqlparser.NewTrackedBuffer(nil) - node.Expr.Format(buf) - node.As = sqlparser.NewColIdent(buf.String()) - } - pullouts, origin, expr, err := pb.findOrigin(rewritten) + pullouts, origin, expr, err := pb.findOrigin(node.Expr) 
if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index d26ea4463bb..dde7536613d 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -24,16 +24,16 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" ) -func buildUnionPlan(union *sqlparser.Union, vschema ContextVSchema) (primitive engine.Primitive, needLastInsertID bool, needsDbName bool, err error) { +func buildUnionPlan(union *sqlparser.Union, vschema ContextVSchema) (engine.Primitive, error) { // For unions, create a pb with anonymous scope. pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(union))) if err := pb.processUnion(union, nil); err != nil { - return nil, false, false, err + return nil, err } if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { - return nil, false, false, err + return nil, err } - return pb.bldr.Primitive(), pb.needsLastInsertID, pb.needsDbName, nil + return pb.bldr.Primitive(), nil } func (pb *primitiveBuilder) processUnion(union *sqlparser.Union, outer *symtab) error { diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index 22bb3ba5761..79bd4df5c36 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -30,37 +30,33 @@ import ( ) // buildUpdatePlan builds the instructions for an UPDATE statement. 
-func buildUpdatePlan(upd *sqlparser.Update, vschema ContextVSchema) (_ *engine.Update, needsLastInsertID bool, needDbName bool, _ error) { +func buildUpdatePlan(upd *sqlparser.Update, vschema ContextVSchema) (*engine.Update, error) { eupd := &engine.Update{ ChangedVindexValues: make(map[string][]sqltypes.PlanValue), } pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(upd))) - err := rewriteExpressions(pb, upd.Exprs) - if err != nil { - return nil, false, false, err - } ro, err := pb.processDMLTable(upd.TableExprs) if err != nil { - return nil, false, false, err + return nil, err } eupd.Keyspace = ro.eroute.Keyspace if !eupd.Keyspace.Sharded { // We only validate non-table subexpressions because the previous analysis has already validated them. if !pb.finalizeUnshardedDMLSubqueries(upd.Exprs, upd.Where, upd.OrderBy, upd.Limit) { - return nil, false, false, errors.New("unsupported: sharded subqueries in DML") + return nil, errors.New("unsupported: sharded subqueries in DML") } eupd.Opcode = engine.UpdateUnsharded // Generate query after all the analysis. Otherwise table name substitutions for // routed tables won't happen. eupd.Query = generateQuery(upd) - return eupd, pb.needsLastInsertID, pb.needsDbName, nil + return eupd, nil } if hasSubquery(upd) { - return nil, false, false, errors.New("unsupported: subqueries in sharded DML") + return nil, errors.New("unsupported: subqueries in sharded DML") } if len(pb.st.tables) != 1 { - return nil, false, false, errors.New("unsupported: multi-table update statement in sharded keyspace") + return nil, errors.New("unsupported: multi-table update statement in sharded keyspace") } // Generate query after all the analysis. 
Otherwise table name substitutions for @@ -75,16 +71,16 @@ func buildUpdatePlan(upd *sqlparser.Update, vschema ContextVSchema) (_ *engine.U eupd.QueryTimeout = queryTimeout(directives) eupd.Table = ro.vschemaTable if eupd.Table == nil { - return nil, false, false, errors.New("internal error: table.vindexTable is mysteriously nil") + return nil, errors.New("internal error: table.vindexTable is mysteriously nil") } if ro.eroute.TargetDestination != nil { if ro.eroute.TargetTabletType != topodatapb.TabletType_MASTER { - return nil, false, false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported: UPDATE statement with a replica target") + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported: UPDATE statement with a replica target") } eupd.Opcode = engine.UpdateByDestination eupd.TargetDestination = ro.eroute.TargetDestination - return eupd, pb.needsLastInsertID, pb.needsDbName, nil + return eupd, nil } eupd.Vindex, eupd.Values, err = getDMLRouting(upd.Where, eupd.Table) @@ -96,31 +92,20 @@ func buildUpdatePlan(upd *sqlparser.Update, vschema ContextVSchema) (_ *engine.U if eupd.Opcode == engine.UpdateScatter { if len(eupd.Table.Owned) != 0 { - return eupd, false, false, errors.New("unsupported: multi shard update on a table with owned lookup vindexes") + return nil, errors.New("unsupported: multi shard update on a table with owned lookup vindexes") } if upd.Limit != nil { - return eupd, false, false, errors.New("unsupported: multi shard update with limit") + return nil, errors.New("unsupported: multi shard update with limit") } } if eupd.ChangedVindexValues, err = buildChangedVindexesValues(eupd, upd, eupd.Table.ColumnVindexes); err != nil { - return nil, false, false, err + return nil, err } if len(eupd.ChangedVindexValues) != 0 { eupd.OwnedVindexQuery = generateUpdateSubquery(upd, eupd.Table) } - return eupd, pb.needsLastInsertID, pb.needsDbName, nil -} - -func rewriteExpressions(pb *primitiveBuilder, exprs sqlparser.UpdateExprs) error { - 
for _, e := range exprs { - rewritten, err := RewriteAndUpdateBuilder(e.Expr, pb) - if err != nil { - return err - } - e.Expr = rewritten - } - return nil + return eupd, nil } // buildChangedVindexesValues adds to the plan all the lookup vindexes that are changing. diff --git a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt index 0090a3ebcdd..2e7f09d1db0 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt +++ b/go/vt/vttablet/tabletserver/planbuilder/testdata/exec_cases.txt @@ -2362,3 +2362,18 @@ options:PassthroughDMLs # named locks are unsafe with server-side connection pooling "select get_lock('foo') from dual" "get_lock() not allowed" + +# select DISTINCT ((1,2),(1,2)) from dual; +"select DISTINCT ((1,2),(1,2)) from dual" +{ + "PlanID": "PASS_SELECT", + "TableName": "dual", + "Permissions": [ + { + "TableName": "dual", + "Role": 0 + } + ], + "FieldQuery": "select ((1, 2), (1, 2)) from dual where 1 != 1", + "FullQuery": "select distinct ((1, 2), (1, 2)) from dual limit :#maxLimit" +} diff --git a/misc/git/hooks/visitorgen b/misc/git/hooks/visitorgen new file mode 100755 index 00000000000..7eb1c41360f --- /dev/null +++ b/misc/git/hooks/visitorgen @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2019 The Vitess Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# this script, which should run before committing code, makes sure that the visitor is re-generated when the ast changes + +make visitor REWRITER=tmp_rewriter.go +if ! cmp -s "tmp_rewriter.go" "go/vt/sqlparser/rewriter.go"; then + echo "The ast.go has changed, but not rewriter.go" + echo "You should 'make visitor' to update the generated rewriter" + rm -f tmp_rewriter.go + exit 1 +fi +rm -f tmp_rewriter.go \ No newline at end of file