diff --git a/.github/workflows/check_make_visitor.yml b/.github/workflows/check_make_visitor.yml index 219e1ece2dc..a23a8680a2e 100644 --- a/.github/workflows/check_make_visitor.yml +++ b/.github/workflows/check_make_visitor.yml @@ -31,5 +31,5 @@ jobs: - name: check_make_visitor run: | - misc/git/hooks/visitorgen + misc/git/hooks/asthelpers diff --git a/Makefile b/Makefile index cecb52b0e21..e83727c811d 100644 --- a/Makefile +++ b/Makefile @@ -103,6 +103,10 @@ parser: make -C go/vt/sqlparser visitor: + >&2 echo "make visitor has been replaced by make asthelpers" + exit 1 + +asthelpers: go run ./go/tools/asthelpergen/main -in ./go/vt/sqlparser -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName" sizegen: @@ -123,7 +127,6 @@ clean: go clean -i ./go/... rm -rf third_party/acolyte rm -rf go/vt/.proto.tmp - rm -rf ./visitorgen # Remove everything including stuff pulled down by bootstrap.sh cleanall: clean diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 9b1768a660a..e22c48936e1 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -23,6 +23,7 @@ import ( "io/ioutil" "log" "path" + "sort" "strings" "github.com/dave/jennifer/jen" @@ -43,52 +44,65 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.` -type generator interface { - visitStruct(t types.Type, stroct *types.Struct) error - visitInterface(t types.Type, iface *types.Interface) error - visitSlice(t types.Type, slice *types.Slice) error - createFile(pkgName string) (string, *jen.File) -} +type ( + generatorSPI interface { + addType(t types.Type) + addFunc(name string, t methodType, code jen.Code) + scope() *types.Scope + findImplementations(iff *types.Interface, impl func(types.Type) error) error + iface() *types.Interface + } + generator2 interface { + interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error + structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error + ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error + ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error + sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error + basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error + } + // astHelperGen finds implementations of the given interface, + // and uses the supplied `generator`s to produce the output code + astHelperGen struct { + DebugTypes bool + mod *packages.Module + sizes types.Sizes + namedIface *types.Named + _iface *types.Interface + gens []generator2 + + functions methods + _scope *types.Scope + todo []types.Type + } -type generatorSPI interface { - addType(t types.Type) - addFunc(name string, t methodType, code jen.Code) - scope() *types.Scope - findImplementations(iff *types.Interface, impl func(types.Type) error) error - iface() *types.Interface + method struct { + name string + code jen.Code + typ methodType + } + + methods []method +) + +func (m methods) Len() int { + return len(m) } -type generator2 interface { - interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error - structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error - ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error - ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error - ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error - sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error - basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error +func (m methods) Less(i, j int) bool { + return m[i].name < m[j].name } -// astHelperGen finds implementations of the given interface, -// and uses the supplied `generator`s to produce the output code -type astHelperGen struct { - DebugTypes bool - mod *packages.Module - sizes types.Sizes - namedIface *types.Named - _iface *types.Interface - gens []generator - gens2 []generator2 - - methods []jen.Code - _scope *types.Scope - todo []types.Type +func (m methods) Swap(i, j int) { + m[i], m[j] = m[j], m[i] } +var _ sort.Interface = (methods)(nil) + func (gen *astHelperGen) iface() *types.Interface { return gen._iface } -func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen { +func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator2) *astHelperGen { return &astHelperGen{ DebugTypes: true, mod: mod, @@ -150,73 +164,11 @@ func (gen *astHelperGen) findImplementations(iff *types.Interface, impl func(typ return nil } -func (gen *astHelperGen) visitStruct(t types.Type, stroct *types.Struct) error { - for _, g := range gen.gens { - err := g.visitStruct(t, stroct) - if err != nil { - return err - } - } - return nil -} - -func (gen *astHelperGen) visitSlice(t types.Type, slice *types.Slice) error { - for _, g := range gen.gens { - err := g.visitSlice(t, slice) - if err != nil { - return err - } - } - return nil -} - -func (gen *astHelperGen) visitInterface(t types.Type, iface *types.Interface) error { - for _, g := range gen.gens { - err := g.visitInterface(t, iface) - if err != nil { - return err - } - } - return nil -} - // GenerateCode is the main loop where we build up the code per file. func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) { pkg := gen.namedIface.Obj().Pkg() - iface, ok := gen._iface.Underlying().(*types.Interface) - if !ok { - return nil, fmt.Errorf("expected interface, but got %T", gen.iface) - } - - err := findImplementations(pkg.Scope(), iface, func(t types.Type) error { - switch n := t.Underlying().(type) { - case *types.Struct: - return gen.visitStruct(t, n) - case *types.Slice: - return gen.visitSlice(t, n) - case *types.Pointer: - strct, isStrct := n.Elem().Underlying().(*types.Struct) - if isStrct { - return gen.visitStruct(t, strct) - } - case *types.Interface: - return gen.visitInterface(t, n) - default: - // do nothing - } - return nil - }) - - if err != nil { - return nil, err - } result := map[string]*jen.File{} - for _, g := range gen.gens { - file, code := g.createFile(pkg.Name()) - fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), file) - result[fullPath] = code - } gen._scope = pkg.Scope() gen.todo = append(gen.todo, gen.namedIface) @@ -300,16 +252,12 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str nt := tt.Type().(*types.Named) - iface := nt.Underlying().(*types.Interface) - - interestingType := func(t types.Type) bool { - return types.Implements(t, iface) - } - rewriter := newRewriterGen(interestingType, nt.Obj().Name()) - generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, rewriter) - generator.gens2 = append(generator.gens2, &equalsGen{}) - generator.gens2 = append(generator.gens2, newCloneGen(exceptCloneType)) - generator.gens2 = append(generator.gens2, &visitGen{}) + generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, + &equalsGen{}, + newCloneGen(exceptCloneType), + &visitGen{}, + &rewriteGen{types.TypeString(nt, noQualifier)}, + ) it, err := generator.GenerateCode() if err != nil { @@ -335,19 +283,11 @@ const ( clone methodType = iota equals visit + rewrite ) func (gen *astHelperGen) addFunc(name string, typ methodType, code jen.Code) { - var comment string - switch typ { - case clone: - comment = " creates a deep clone of the input." - case equals: - comment = " does deep equals between the two objects." - case visit: - comment = " will visit all parts of the AST" - } - gen.methods = append(gen.methods, jen.Comment(name+comment), code) + gen.functions = append(gen.functions, method{name: name, code: code, typ: typ}) } func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) { @@ -406,15 +346,23 @@ func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) { alreadyDone[typeName] = true } - for _, method := range gen.methods { - out.Add(method) + sort.Sort(gen.functions) + + for _, m := range gen.functions { + switch m.typ { + case clone: + out.Add(jen.Comment(fmt.Sprintf("%s creates a deep clone of the input.", m.name))) + case equals: + out.Add(jen.Comment(fmt.Sprintf("%s does deep equals between the two objects.", m.name))) + } + out.Add(m.code) } return "ast_helper.go", out } func (gen *astHelperGen) allGenerators(f func(g generator2) error) { - for _, g := range gen.gens2 { + for _, g := range gen.gens { err := f(g) if err != nil { diff --git a/go/tools/asthelpergen/equals_gen.go b/go/tools/asthelpergen/equals_gen.go index 9def354589f..24b89141b5c 100644 --- a/go/tools/asthelpergen/equals_gen.go +++ b/go/tools/asthelpergen/equals_gen.go @@ -235,7 +235,3 @@ func (e equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSP func (e equalsGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { return nil } - -func (e equalsGen) ptrToOtherMethod(types.Type, *types.Pointer, generatorSPI) error { - return nil -} diff --git a/go/tools/asthelpergen/integration/ast_helper.go b/go/tools/asthelpergen/integration/ast_helper.go index c65490df329..63ee62f998d 100644 --- a/go/tools/asthelpergen/integration/ast_helper.go +++ b/go/tools/asthelpergen/integration/ast_helper.go @@ -17,6 +17,216 @@ limitations under the License. package integration +import ( + vtrpc "vitess.io/vitess/go/vt/proto/vtrpc" + vterrors "vitess.io/vitess/go/vt/vterrors" +) + +// CloneAST creates a deep clone of the input. +func CloneAST(in AST) AST { + if in == nil { + return nil + } + switch in := in.(type) { + case BasicType: + return in + case Bytes: + return CloneBytes(in) + case InterfaceContainer: + return CloneInterfaceContainer(in) + case InterfaceSlice: + return CloneInterfaceSlice(in) + case *Leaf: + return CloneRefOfLeaf(in) + case LeafSlice: + return CloneLeafSlice(in) + case *NoCloneType: + return CloneRefOfNoCloneType(in) + case *RefContainer: + return CloneRefOfRefContainer(in) + case *RefSliceContainer: + return CloneRefOfRefSliceContainer(in) + case *SubImpl: + return CloneRefOfSubImpl(in) + case ValueContainer: + return CloneValueContainer(in) + case ValueSliceContainer: + return CloneValueSliceContainer(in) + default: + // this should never happen + return nil + } +} + +// CloneBytes creates a deep clone of the input. +func CloneBytes(n Bytes) Bytes { + res := make(Bytes, 0, len(n)) + copy(res, n) + return res +} + +// CloneInterfaceContainer creates a deep clone of the input. +func CloneInterfaceContainer(n InterfaceContainer) InterfaceContainer { + return *CloneRefOfInterfaceContainer(&n) +} + +// CloneInterfaceSlice creates a deep clone of the input. +func CloneInterfaceSlice(n InterfaceSlice) InterfaceSlice { + res := make(InterfaceSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneLeafSlice creates a deep clone of the input. +func CloneLeafSlice(n LeafSlice) LeafSlice { + res := make(LeafSlice, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneRefOfBool creates a deep clone of the input. +func CloneRefOfBool(n *bool) *bool { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfInterfaceContainer creates a deep clone of the input. +func CloneRefOfInterfaceContainer(n *InterfaceContainer) *InterfaceContainer { + if n == nil { + return nil + } + out := *n + out.v = n.v + return &out +} + +// CloneRefOfLeaf creates a deep clone of the input. +func CloneRefOfLeaf(n *Leaf) *Leaf { + if n == nil { + return nil + } + out := *n + return &out +} + +// CloneRefOfNoCloneType creates a deep clone of the input. +func CloneRefOfNoCloneType(n *NoCloneType) *NoCloneType { + return n +} + +// CloneRefOfRefContainer creates a deep clone of the input. +func CloneRefOfRefContainer(n *RefContainer) *RefContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfRefSliceContainer creates a deep clone of the input. +func CloneRefOfRefSliceContainer(n *RefSliceContainer) *RefSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfInt(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} + +// CloneRefOfSubImpl creates a deep clone of the input. +func CloneRefOfSubImpl(n *SubImpl) *SubImpl { + if n == nil { + return nil + } + out := *n + out.inner = CloneSubIface(n.inner) + out.field = CloneRefOfBool(n.field) + return &out +} + +// CloneRefOfValueContainer creates a deep clone of the input. +func CloneRefOfValueContainer(n *ValueContainer) *ValueContainer { + if n == nil { + return nil + } + out := *n + out.ASTType = CloneAST(n.ASTType) + out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) + return &out +} + +// CloneRefOfValueSliceContainer creates a deep clone of the input. +func CloneRefOfValueSliceContainer(n *ValueSliceContainer) *ValueSliceContainer { + if n == nil { + return nil + } + out := *n + out.ASTElements = CloneSliceOfAST(n.ASTElements) + out.NotASTElements = CloneSliceOfInt(n.NotASTElements) + out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) + return &out +} + +// CloneSliceOfAST creates a deep clone of the input. +func CloneSliceOfAST(n []AST) []AST { + res := make([]AST, 0, len(n)) + for _, x := range n { + res = append(res, CloneAST(x)) + } + return res +} + +// CloneSliceOfInt creates a deep clone of the input. +func CloneSliceOfInt(n []int) []int { + res := make([]int, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfRefOfLeaf creates a deep clone of the input. +func CloneSliceOfRefOfLeaf(n []*Leaf) []*Leaf { + res := make([]*Leaf, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfLeaf(x)) + } + return res +} + +// CloneSubIface creates a deep clone of the input. +func CloneSubIface(in SubIface) SubIface { + if in == nil { + return nil + } + switch in := in.(type) { + case *SubImpl: + return CloneRefOfSubImpl(in) + default: + // this should never happen + return nil + } +} + +// CloneValueContainer creates a deep clone of the input. +func CloneValueContainer(n ValueContainer) ValueContainer { + return *CloneRefOfValueContainer(&n) +} + +// CloneValueSliceContainer creates a deep clone of the input. +func CloneValueSliceContainer(n ValueSliceContainer) ValueSliceContainer { + return *CloneRefOfValueSliceContainer(&n) +} + // EqualsAST does deep equals between the two objects. func EqualsAST(inA, inB AST) bool { if inA == nil && inB == nil { @@ -104,102 +314,17 @@ func EqualsAST(inA, inB AST) bool { } } -// CloneAST creates a deep clone of the input. -func CloneAST(in AST) AST { - if in == nil { - return nil +// EqualsBytes does deep equals between the two objects. +func EqualsBytes(a, b Bytes) bool { + if len(a) != len(b) { + return false } - switch in := in.(type) { - case BasicType: - return in - case Bytes: - return CloneBytes(in) - case InterfaceContainer: - return CloneInterfaceContainer(in) - case InterfaceSlice: - return CloneInterfaceSlice(in) - case *Leaf: - return CloneRefOfLeaf(in) - case LeafSlice: - return CloneLeafSlice(in) - case *NoCloneType: - return CloneRefOfNoCloneType(in) - case *RefContainer: - return CloneRefOfRefContainer(in) - case *RefSliceContainer: - return CloneRefOfRefSliceContainer(in) - case *SubImpl: - return CloneRefOfSubImpl(in) - case ValueContainer: - return CloneValueContainer(in) - case ValueSliceContainer: - return CloneValueSliceContainer(in) - default: - // this should never happen - return nil + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } } -} - -// VisitAST will visit all parts of the AST -func VisitAST(in AST, f Visit) error { - if in == nil { - return nil - } - switch in := in.(type) { - case BasicType: - return VisitBasicType(in, f) - case Bytes: - return VisitBytes(in, f) - case InterfaceContainer: - return VisitInterfaceContainer(in, f) - case InterfaceSlice: - return VisitInterfaceSlice(in, f) - case *Leaf: - return VisitRefOfLeaf(in, f) - case LeafSlice: - return VisitLeafSlice(in, f) - case *NoCloneType: - return VisitRefOfNoCloneType(in, f) - case *RefContainer: - return VisitRefOfRefContainer(in, f) - case *RefSliceContainer: - return VisitRefOfRefSliceContainer(in, f) - case *SubImpl: - return VisitRefOfSubImpl(in, f) - case ValueContainer: - return VisitValueContainer(in, f) - case ValueSliceContainer: - return VisitValueSliceContainer(in, f) - default: - // this should never happen - return nil - } -} - -// EqualsBytes does deep equals between the two objects. -func EqualsBytes(a, b Bytes) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } - } - return true -} - -// CloneBytes creates a deep clone of the input. -func CloneBytes(n Bytes) Bytes { - res := make(Bytes, 0, len(n)) - copy(res, n) - return res -} - -// VisitBytes will visit all parts of the AST -func VisitBytes(in Bytes, f Visit) error { - _, err := f(in) - return err + return true } // EqualsInterfaceContainer does deep equals between the two objects. @@ -207,19 +332,6 @@ func EqualsInterfaceContainer(a, b InterfaceContainer) bool { return true } -// CloneInterfaceContainer creates a deep clone of the input. -func CloneInterfaceContainer(n InterfaceContainer) InterfaceContainer { - return *CloneRefOfInterfaceContainer(&n) -} - -// VisitInterfaceContainer will visit all parts of the AST -func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - // EqualsInterfaceSlice does deep equals between the two objects. func EqualsInterfaceSlice(a, b InterfaceSlice) bool { if len(a) != len(b) { @@ -233,98 +345,50 @@ func EqualsInterfaceSlice(a, b InterfaceSlice) bool { return true } -// CloneInterfaceSlice creates a deep clone of the input. -func CloneInterfaceSlice(n InterfaceSlice) InterfaceSlice { - res := make(InterfaceSlice, 0, len(n)) - for _, x := range n { - res = append(res, CloneAST(x)) - } - return res -} - -// VisitInterfaceSlice will visit all parts of the AST -func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsLeafSlice does deep equals between the two objects. +func EqualsLeafSlice(a, b LeafSlice) bool { + if len(a) != len(b) { + return false } - for _, el := range in { - if err := VisitAST(el, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfLeaf(a[i], b[i]) { + return false } } - return nil + return true } -// EqualsRefOfLeaf does deep equals between the two objects. -func EqualsRefOfLeaf(a, b *Leaf) bool { +// EqualsRefOfBool does deep equals between the two objects. +func EqualsRefOfBool(a, b *bool) bool { if a == b { return true } if a == nil || b == nil { return false } - return a.v == b.v -} - -// CloneRefOfLeaf creates a deep clone of the input. -func CloneRefOfLeaf(n *Leaf) *Leaf { - if n == nil { - return nil - } - out := *n - return &out + return *a == *b } -// VisitRefOfLeaf will visit all parts of the AST -func VisitRefOfLeaf(in *Leaf, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfInterfaceContainer does deep equals between the two objects. +func EqualsRefOfInterfaceContainer(a, b *InterfaceContainer) bool { + if a == b { + return true } - return nil -} - -// EqualsLeafSlice does deep equals between the two objects. -func EqualsLeafSlice(a, b LeafSlice) bool { - if len(a) != len(b) { + if a == nil || b == nil { return false } - for i := 0; i < len(a); i++ { - if !EqualsRefOfLeaf(a[i], b[i]) { - return false - } - } return true } -// CloneLeafSlice creates a deep clone of the input. -func CloneLeafSlice(n LeafSlice) LeafSlice { - res := make(LeafSlice, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfLeaf(x)) - } - return res -} - -// VisitLeafSlice will visit all parts of the AST -func VisitLeafSlice(in LeafSlice, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfLeaf does deep equals between the two objects. +func EqualsRefOfLeaf(a, b *Leaf) bool { + if a == b { + return true } - for _, el := range in { - if err := VisitRefOfLeaf(el, f); err != nil { - return err - } + if a == nil || b == nil { + return false } - return nil + return a.v == b.v } // EqualsRefOfNoCloneType does deep equals between the two objects. @@ -338,22 +402,6 @@ func EqualsRefOfNoCloneType(a, b *NoCloneType) bool { return a.v == b.v } -// CloneRefOfNoCloneType creates a deep clone of the input. -func CloneRefOfNoCloneType(n *NoCloneType) *NoCloneType { - return n -} - -// VisitRefOfNoCloneType will visit all parts of the AST -func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - // EqualsRefOfRefContainer does deep equals between the two objects. func EqualsRefOfRefContainer(a, b *RefContainer) bool { if a == b { @@ -367,34 +415,6 @@ func EqualsRefOfRefContainer(a, b *RefContainer) bool { EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) } -// CloneRefOfRefContainer creates a deep clone of the input. -func CloneRefOfRefContainer(n *RefContainer) *RefContainer { - if n == nil { - return nil - } - out := *n - out.ASTType = CloneAST(n.ASTType) - out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) - return &out -} - -// VisitRefOfRefContainer will visit all parts of the AST -func VisitRefOfRefContainer(in *RefContainer, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitAST(in.ASTType, f); err != nil { - return err - } - if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { - return err - } - return nil -} - // EqualsRefOfRefSliceContainer does deep equals between the two objects. func EqualsRefOfRefSliceContainer(a, b *RefSliceContainer) bool { if a == b { @@ -408,39 +428,6 @@ func EqualsRefOfRefSliceContainer(a, b *RefSliceContainer) bool { EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) } -// CloneRefOfRefSliceContainer creates a deep clone of the input. -func CloneRefOfRefSliceContainer(n *RefSliceContainer) *RefSliceContainer { - if n == nil { - return nil - } - out := *n - out.ASTElements = CloneSliceOfAST(n.ASTElements) - out.NotASTElements = CloneSliceOfInt(n.NotASTElements) - out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) - return &out -} - -// VisitRefOfRefSliceContainer will visit all parts of the AST -func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in.ASTElements { - if err := VisitAST(el, f); err != nil { - return err - } - } - for _, el := range in.ASTImplementationElements { - if err := VisitRefOfLeaf(el, f); err != nil { - return err - } - } - return nil -} - // EqualsRefOfSubImpl does deep equals between the two objects. func EqualsRefOfSubImpl(a, b *SubImpl) bool { if a == b { @@ -453,85 +440,69 @@ func EqualsRefOfSubImpl(a, b *SubImpl) bool { EqualsRefOfBool(a.field, b.field) } -// CloneRefOfSubImpl creates a deep clone of the input. -func CloneRefOfSubImpl(n *SubImpl) *SubImpl { - if n == nil { - return nil +// EqualsRefOfValueContainer does deep equals between the two objects. +func EqualsRefOfValueContainer(a, b *ValueContainer) bool { + if a == b { + return true } - out := *n - out.inner = CloneSubIface(n.inner) - out.field = CloneRefOfBool(n.field) - return &out + if a == nil || b == nil { + return false + } + return a.NotASTType == b.NotASTType && + EqualsAST(a.ASTType, b.ASTType) && + EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) } -// VisitRefOfSubImpl will visit all parts of the AST -func VisitRefOfSubImpl(in *SubImpl, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsRefOfValueSliceContainer does deep equals between the two objects. +func EqualsRefOfValueSliceContainer(a, b *ValueSliceContainer) bool { + if a == b { + return true } - if err := VisitSubIface(in.inner, f); err != nil { - return err - } - return nil -} - -// EqualsValueContainer does deep equals between the two objects. -func EqualsValueContainer(a, b ValueContainer) bool { - return a.NotASTType == b.NotASTType && - EqualsAST(a.ASTType, b.ASTType) && - EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) -} - -// CloneValueContainer creates a deep clone of the input. -func CloneValueContainer(n ValueContainer) ValueContainer { - return *CloneRefOfValueContainer(&n) -} - -// VisitValueContainer will visit all parts of the AST -func VisitValueContainer(in ValueContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitAST(in.ASTType, f); err != nil { - return err - } - if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { - return err + if a == nil || b == nil { + return false } - return nil -} - -// EqualsValueSliceContainer does deep equals between the two objects. -func EqualsValueSliceContainer(a, b ValueSliceContainer) bool { return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) } -// CloneValueSliceContainer creates a deep clone of the input. -func CloneValueSliceContainer(n ValueSliceContainer) ValueSliceContainer { - return *CloneRefOfValueSliceContainer(&n) +// EqualsSliceOfAST does deep equals between the two objects. +func EqualsSliceOfAST(a, b []AST) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsAST(a[i], b[i]) { + return false + } + } + return true } -// VisitValueSliceContainer will visit all parts of the AST -func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err +// EqualsSliceOfInt does deep equals between the two objects. +func EqualsSliceOfInt(a, b []int) bool { + if len(a) != len(b) { + return false } - for _, el := range in.ASTElements { - if err := VisitAST(el, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false } } - for _, el := range in.ASTImplementationElements { - if err := VisitRefOfLeaf(el, f); err != nil { - return err + return true +} + +// EqualsSliceOfRefOfLeaf does deep equals between the two objects. +func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfLeaf(a[i], b[i]) { + return false } } - return nil + return true } // EqualsSubIface does deep equals between the two objects. @@ -555,181 +526,168 @@ func EqualsSubIface(inA, inB SubIface) bool { } } -// CloneSubIface creates a deep clone of the input. -func CloneSubIface(in SubIface) SubIface { - if in == nil { - return nil - } - switch in := in.(type) { - case *SubImpl: - return CloneRefOfSubImpl(in) - default: - // this should never happen - return nil - } +// EqualsValueContainer does deep equals between the two objects. +func EqualsValueContainer(a, b ValueContainer) bool { + return a.NotASTType == b.NotASTType && + EqualsAST(a.ASTType, b.ASTType) && + EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) } -// VisitSubIface will visit all parts of the AST -func VisitSubIface(in SubIface, f Visit) error { +// EqualsValueSliceContainer does deep equals between the two objects. +func EqualsValueSliceContainer(a, b ValueSliceContainer) bool { + return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && + EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && + EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) +} +func VisitAST(in AST, f Visit) error { if in == nil { return nil } switch in := in.(type) { + case BasicType: + return VisitBasicType(in, f) + case Bytes: + return VisitBytes(in, f) + case InterfaceContainer: + return VisitInterfaceContainer(in, f) + case InterfaceSlice: + return VisitInterfaceSlice(in, f) + case *Leaf: + return VisitRefOfLeaf(in, f) + case LeafSlice: + return VisitLeafSlice(in, f) + case *NoCloneType: + return VisitRefOfNoCloneType(in, f) + case *RefContainer: + return VisitRefOfRefContainer(in, f) + case *RefSliceContainer: + return VisitRefOfRefSliceContainer(in, f) case *SubImpl: return VisitRefOfSubImpl(in, f) + case ValueContainer: + return VisitValueContainer(in, f) + case ValueSliceContainer: + return VisitValueSliceContainer(in, f) default: // this should never happen return nil } } - -// VisitBasicType will visit all parts of the AST func VisitBasicType(in BasicType, f Visit) error { _, err := f(in) return err } - -// EqualsRefOfInterfaceContainer does deep equals between the two objects. -func EqualsRefOfInterfaceContainer(a, b *InterfaceContainer) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true +func VisitBytes(in Bytes, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfInterfaceContainer creates a deep clone of the input. -func CloneRefOfInterfaceContainer(n *InterfaceContainer) *InterfaceContainer { - if n == nil { - return nil +func VisitInterfaceContainer(in InterfaceContainer, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err } - out := *n - out.v = n.v - return &out + return nil } - -// VisitRefOfInterfaceContainer will visit all parts of the AST -func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { +func VisitInterfaceSlice(in InterfaceSlice, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsSliceOfAST does deep equals between the two objects. -func EqualsSliceOfAST(a, b []AST) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsAST(a[i], b[i]) { - return false + for _, el := range in { + if err := VisitAST(el, f); err != nil { + return err } } - return true + return nil } - -// CloneSliceOfAST creates a deep clone of the input. -func CloneSliceOfAST(n []AST) []AST { - res := make([]AST, 0, len(n)) - for _, x := range n { - res = append(res, CloneAST(x)) +func VisitLeafSlice(in LeafSlice, f Visit) error { + if in == nil { + return nil } - return res -} - -// EqualsSliceOfInt does deep equals between the two objects. -func EqualsSliceOfInt(a, b []int) bool { - if len(a) != len(b) { - return false + if cont, err := f(in); err != nil || !cont { + return err } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false + for _, el := range in { + if err := VisitRefOfLeaf(el, f); err != nil { + return err } } - return true -} - -// CloneSliceOfInt creates a deep clone of the input. -func CloneSliceOfInt(n []int) []int { - res := make([]int, 0, len(n)) - copy(res, n) - return res + return nil } - -// EqualsSliceOfRefOfLeaf does deep equals between the two objects. -func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool { - if len(a) != len(b) { - return false +func VisitRefOfInterfaceContainer(in *InterfaceContainer, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfLeaf(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true + return nil } - -// CloneSliceOfRefOfLeaf creates a deep clone of the input. -func CloneSliceOfRefOfLeaf(n []*Leaf) []*Leaf { - res := make([]*Leaf, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfLeaf(x)) +func VisitRefOfLeaf(in *Leaf, f Visit) error { + if in == nil { + return nil } - return res + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// EqualsRefOfBool does deep equals between the two objects. -func EqualsRefOfBool(a, b *bool) bool { - if a == b { - return true +func VisitRefOfNoCloneType(in *NoCloneType, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return *a == *b + return nil } - -// CloneRefOfBool creates a deep clone of the input. -func CloneRefOfBool(n *bool) *bool { - if n == nil { +func VisitRefOfRefContainer(in *RefContainer, f Visit) error { + if in == nil { return nil } - out := *n - return &out -} - -// EqualsRefOfValueContainer does deep equals between the two objects. -func EqualsRefOfValueContainer(a, b *ValueContainer) bool { - if a == b { - return true + if cont, err := f(in); err != nil || !cont { + return err } - if a == nil || b == nil { - return false + if err := VisitAST(in.ASTType, f); err != nil { + return err } - return a.NotASTType == b.NotASTType && - EqualsAST(a.ASTType, b.ASTType) && - EqualsRefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) + if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { + return err + } + return nil } - -// CloneRefOfValueContainer creates a deep clone of the input. -func CloneRefOfValueContainer(n *ValueContainer) *ValueContainer { - if n == nil { +func VisitRefOfRefSliceContainer(in *RefSliceContainer, f Visit) error { + if in == nil { return nil } - out := *n - out.ASTType = CloneAST(n.ASTType) - out.ASTImplementationType = CloneRefOfLeaf(n.ASTImplementationType) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + for _, el := range in.ASTElements { + if err := VisitAST(el, f); err != nil { + return err + } + } + for _, el := range in.ASTImplementationElements { + if err := VisitRefOfLeaf(el, f); err != nil { + return err + } + } + return nil +} +func VisitRefOfSubImpl(in *SubImpl, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitSubIface(in.inner, f); err != nil { + return err + } + return nil } - -// VisitRefOfValueContainer will visit all parts of the AST func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { if in == nil { return nil @@ -745,37 +703,50 @@ func VisitRefOfValueContainer(in *ValueContainer, f Visit) error { } return nil } - -// EqualsRefOfValueSliceContainer does deep equals between the two objects. -func EqualsRefOfValueSliceContainer(a, b *ValueSliceContainer) bool { - if a == b { - return true +func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSliceOfAST(a.ASTElements, b.ASTElements) && - EqualsSliceOfInt(a.NotASTElements, b.NotASTElements) && - EqualsSliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) -} - -// CloneRefOfValueSliceContainer creates a deep clone of the input. -func CloneRefOfValueSliceContainer(n *ValueSliceContainer) *ValueSliceContainer { - if n == nil { - return nil + for _, el := range in.ASTElements { + if err := VisitAST(el, f); err != nil { + return err + } } - out := *n - out.ASTElements = CloneSliceOfAST(n.ASTElements) - out.NotASTElements = CloneSliceOfInt(n.NotASTElements) - out.ASTImplementationElements = CloneSliceOfRefOfLeaf(n.ASTImplementationElements) - return &out + for _, el := range in.ASTImplementationElements { + if err := VisitRefOfLeaf(el, f); err != nil { + return err + } + } + return nil } - -// VisitRefOfValueSliceContainer will visit all parts of the AST -func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { +func VisitSubIface(in SubIface, f Visit) error { if in == nil { return nil } + switch in := in.(type) { + case *SubImpl: + return VisitRefOfSubImpl(in, f) + default: + // this should never happen + return nil + } +} +func VisitValueContainer(in ValueContainer, f Visit) error { + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitAST(in.ASTType, f); err != nil { + return err + } + if err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil { + return err + } + return nil +} +func VisitValueSliceContainer(in ValueSliceContainer, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } @@ -791,3 +762,482 @@ func VisitRefOfValueSliceContainer(in *ValueSliceContainer, f Visit) error { } return nil } +func (a *application) rewriteAST(parent AST, node AST, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case BasicType: + return a.rewriteBasicType(parent, node, replacer) + case Bytes: + return a.rewriteBytes(parent, node, replacer) + case InterfaceContainer: + return a.rewriteInterfaceContainer(parent, node, replacer) + case InterfaceSlice: + return a.rewriteInterfaceSlice(parent, node, replacer) + case *Leaf: + return a.rewriteRefOfLeaf(parent, node, replacer) + case LeafSlice: + return a.rewriteLeafSlice(parent, node, replacer) + case *NoCloneType: + return a.rewriteRefOfNoCloneType(parent, node, replacer) + case *RefContainer: + return a.rewriteRefOfRefContainer(parent, node, replacer) + case *RefSliceContainer: + return a.rewriteRefOfRefSliceContainer(parent, node, replacer) + case *SubImpl: + return a.rewriteRefOfSubImpl(parent, node, replacer) + case ValueContainer: + return a.rewriteValueContainer(parent, node, replacer) + case ValueSliceContainer: + return a.rewriteValueSliceContainer(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteBasicType(parent AST, node BasicType, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteInterfaceContainer(parent AST, node InterfaceContainer, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if err != nil { + return err + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + parent.(InterfaceSlice)[i] = newNode.(AST) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(LeafSlice)[i] = newNode.(*Leaf) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfInterfaceContainer(parent AST, node *InterfaceContainer, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfLeaf(parent AST, node *Leaf, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfNoCloneType(parent AST, node *NoCloneType, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfRefContainer(parent AST, node *RefContainer, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceContainer, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node.ASTElements { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }); errF != nil { + return errF + } + } + for i, el := range node.ASTImplementationElements { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfSubImpl(parent AST, node *SubImpl, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSubIface(node, node.inner, func(newNode, parent AST) { + parent.(*SubImpl).inner = newNode.(SubIface) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfValueContainer(parent AST, node *ValueContainer, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*ValueContainer).ASTType = newNode.(AST) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + parent.(*ValueContainer).ASTImplementationType = newNode.(*Leaf) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSliceContainer, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node.ASTElements { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTElements[i] = newNode.(AST) + }); errF != nil { + return errF + } + } + for i, el := range node.ASTImplementationElements { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(*ValueSliceContainer).ASTImplementationElements[i] = newNode.(*Leaf) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteSubIface(parent AST, node SubIface, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *SubImpl: + return a.rewriteRefOfSubImpl(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteValueContainer(parent AST, node ValueContainer, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteAST(node, node.ASTType, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTType' on 'ValueContainer'") + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLeaf(node, node.ASTImplementationType, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationType' on 'ValueContainer'") + }); errF != nil { + return errF + } + if err != nil { + return err + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceContainer, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for _, el := range node.ASTElements { + if errF := a.rewriteAST(node, el, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") + }); errF != nil { + return errF + } + } + for _, el := range node.ASTImplementationElements { + if errF := a.rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'ASTImplementationElements' on 'ValueSliceContainer'") + }); errF != nil { + return errF + } + } + if err != nil { + return err + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} diff --git a/go/tools/asthelpergen/integration/integration_rewriter_test.go b/go/tools/asthelpergen/integration/integration_rewriter_test.go index 9648abfef44..a5ad57ef9ab 100644 --- a/go/tools/asthelpergen/integration/integration_rewriter_test.go +++ b/go/tools/asthelpergen/integration/integration_rewriter_test.go @@ -21,6 +21,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -32,7 +34,8 @@ func TestRewriteVisitRefContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) expected := []step{ Pre{containerContainer}, @@ -55,7 +58,8 @@ func TestRewriteVisitValueContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) expected := []step{ Pre{containerContainer}, @@ -80,7 +84,8 @@ func TestRewriteVisitRefSliceContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{containerContainer}, @@ -108,7 +113,8 @@ func TestRewriteVisitValueSliceContainer(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(containerContainer, tv.pre, tv.post) + _, err := Rewrite(containerContainer, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{containerContainer}, @@ -144,7 +150,8 @@ func TestRewriteVisitInterfaceSlice(t *testing.T) { tv := &rewriteTestVisitor{} - Rewrite(ast, tv.pre, tv.post) + _, err := Rewrite(ast, tv.pre, tv.post) + require.NoError(t, err) tv.assertEquals(t, []step{ Pre{ast}, @@ -169,20 +176,22 @@ func TestRewriteVisitRefContainerReplace(t *testing.T) { } // rewrite field of type AST - Rewrite(ast, func(cursor *Cursor) bool { + _, err := Rewrite(ast, func(cursor *Cursor) bool { leaf, ok := cursor.node.(*RefContainer) if ok && leaf.NotASTType == 12 { cursor.Replace(&Leaf{99}) } return true }, nil) + require.NoError(t, err) assert.Equal(t, &RefContainer{ ASTType: &Leaf{99}, ASTImplementationType: &Leaf{2}, }, ast) - Rewrite(ast, rewriteLeaf(2, 55), nil) + _, err = Rewrite(ast, rewriteLeaf(2, 55), nil) + require.NoError(t, err) assert.Equal(t, &RefContainer{ ASTType: &Leaf{99}, @@ -196,13 +205,7 @@ func TestRewriteVisitValueContainerReplace(t *testing.T) { ASTImplementationType: &Leaf{2}, } - defer func() { - if r := recover(); r != nil { - assert.Contains(t, r, "ValueContainer ASTType") - } - }() - - Rewrite(ast, func(cursor *Cursor) bool { + _, err := Rewrite(ast, func(cursor *Cursor) bool { leaf, ok := cursor.node.(ValueContainer) if ok && leaf.NotASTType == 12 { cursor.Replace(&Leaf{99}) @@ -210,7 +213,7 @@ func TestRewriteVisitValueContainerReplace(t *testing.T) { return true }, nil) - t.Fatalf("should not get here") + require.Error(t, err) } func TestRewriteVisitValueContainerReplace2(t *testing.T) { @@ -219,15 +222,36 @@ func TestRewriteVisitValueContainerReplace2(t *testing.T) { ASTImplementationType: &Leaf{2}, } - defer func() { - if r := recover(); r != nil { - assert.Contains(t, r, "ValueContainer ASTImplementationType") - } - }() + _, err := Rewrite(ast, rewriteLeaf(2, 10), nil) + require.Error(t, err) +} + +func TestRewriteVisitRefContainerPreOrPostOnly(t *testing.T) { + leaf1 := &Leaf{1} + leaf2 := &Leaf{2} + container := &RefContainer{ASTType: leaf1, ASTImplementationType: leaf2} + containerContainer := &RefContainer{ASTType: container} + + tv := &rewriteTestVisitor{} - Rewrite(ast, rewriteLeaf(2, 10), nil) + _, err := Rewrite(containerContainer, tv.pre, nil) + require.NoError(t, err) + tv.assertEquals(t, []step{ + Pre{containerContainer}, + Pre{container}, + Pre{leaf1}, + Pre{leaf2}, + }) - t.Fatalf("should not get here") + tv = &rewriteTestVisitor{} + _, err = Rewrite(containerContainer, nil, tv.post) + require.NoError(t, err) + tv.assertEquals(t, []step{ + Post{leaf1}, + Post{leaf2}, + Post{container}, + Post{containerContainer}, + }) } func rewriteLeaf(from, to int) func(*Cursor) bool { @@ -246,14 +270,16 @@ func TestRefSliceContainerReplace(t *testing.T) { ASTImplementationElements: []*Leaf{{3}, {4}}, } - Rewrite(ast, rewriteLeaf(2, 42), nil) + _, err := Rewrite(ast, rewriteLeaf(2, 42), nil) + require.NoError(t, err) assert.Equal(t, &RefSliceContainer{ ASTElements: []AST{&Leaf{1}, &Leaf{42}}, ASTImplementationElements: []*Leaf{{3}, {4}}, }, ast) - Rewrite(ast, rewriteLeaf(3, 88), nil) + _, err = Rewrite(ast, rewriteLeaf(3, 88), nil) + require.NoError(t, err) assert.Equal(t, &RefSliceContainer{ ASTElements: []AST{&Leaf{1}, &Leaf{42}}, @@ -272,7 +298,7 @@ func (r Pre) String() string { return fmt.Sprintf("Pre(%s)", r.el.String()) } func (r Post) String() string { - return fmt.Sprintf("Pre(%s)", r.el.String()) + return fmt.Sprintf("Post(%s)", r.el.String()) } type Post struct { @@ -327,63 +353,3 @@ func (tv *rewriteTestVisitor) assertEquals(t *testing.T, expected []step) { } } - -// below follows two different ways of creating the replacement method for slices, and benchmark -// between them. Diff seems to be very small, so I'll use the most readable form -type replaceA int - -func (r *replaceA) replace(newNode, container AST) { - container.(InterfaceSlice)[int(*r)] = newNode.(AST) -} - -func (r *replaceA) inc() { - *r++ -} - -func replaceB(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(InterfaceSlice)[idx] = newNode.(AST) - } -} - -func BenchmarkSliceReplacerA(b *testing.B) { - islice := make(InterfaceSlice, 20) - for i := range islice { - islice[i] = &Leaf{i} - } - a := &application{ - pre: func(c *Cursor) bool { - return true - }, - post: nil, - cursor: Cursor{}, - } - - for i := 0; i < b.N; i++ { - replacer := replaceA(0) - for _, el := range islice { - a.apply(islice, el, replacer.replace) - replacer.inc() - } - } -} - -func BenchmarkSliceReplacerB(b *testing.B) { - islice := make(InterfaceSlice, 20) - for i := range islice { - islice[i] = &Leaf{i} - } - a := &application{ - pre: func(c *Cursor) bool { - return true - }, - post: nil, - cursor: Cursor{}, - } - - for i := 0; i < b.N; i++ { - for x, el := range islice { - a.apply(islice, el, replaceB(x)) - } - } -} diff --git a/go/tools/asthelpergen/integration/rewriter.go b/go/tools/asthelpergen/integration/rewriter.go deleted file mode 100644 index 300ccef16ea..00000000000 --- a/go/tools/asthelpergen/integration/rewriter.go +++ /dev/null @@ -1,102 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -// Code generated by ASTHelperGen. DO NOT EDIT. - -package integration - -func (a *application) apply(parent, node AST, replacer replacerFunc) { - if node == nil || isNilValue(node) { - return - } - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - if a.pre != nil && !a.pre(&a.cursor) { - a.cursor = saved - return - } - switch n := node.(type) { - case Bytes: - case InterfaceContainer: - case InterfaceSlice: - for x, el := range n { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(InterfaceSlice)[idx] = newNode.(AST) - } - }(x)) - } - case *Leaf: - case LeafSlice: - for x, el := range n { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(LeafSlice)[idx] = newNode.(*Leaf) - } - }(x)) - } - case *NoCloneType: - case *RefContainer: - a.apply(node, n.ASTType, func(newNode, parent AST) { - parent.(*RefContainer).ASTType = newNode.(AST) - }) - a.apply(node, n.ASTImplementationType, func(newNode, parent AST) { - parent.(*RefContainer).ASTImplementationType = newNode.(*Leaf) - }) - case *RefSliceContainer: - for x, el := range n.ASTElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(*RefSliceContainer).ASTElements[idx] = newNode.(AST) - } - }(x)) - } - for x, el := range n.ASTImplementationElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(*RefSliceContainer).ASTImplementationElements[idx] = newNode.(*Leaf) - } - }(x)) - } - case *SubImpl: - a.apply(node, n.inner, func(newNode, parent AST) { - parent.(*SubImpl).inner = newNode.(SubIface) - }) - case ValueContainer: - a.apply(node, n.ASTType, replacePanic("ValueContainer ASTType")) - a.apply(node, n.ASTImplementationType, replacePanic("ValueContainer ASTImplementationType")) - case ValueSliceContainer: - for x, el := range n.ASTElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(ValueSliceContainer).ASTElements[idx] = newNode.(AST) - } - }(x)) - } - for x, el := range n.ASTImplementationElements { - a.apply(node, el, func(idx int) func(AST, AST) { - return func(newNode, container AST) { - container.(ValueSliceContainer).ASTImplementationElements[idx] = newNode.(*Leaf) - } - }(x)) - } - } - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - a.cursor = saved -} diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 3a2da19be80..6ca3df82bde 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -17,7 +17,6 @@ limitations under the License. package integration import ( - "reflect" "strings" ) @@ -40,11 +39,6 @@ func sliceStringLeaf(els ...*Leaf) string { // the methods below are what the generated code expected to be there in the package -type application struct { - pre, post ApplyFunc - cursor Cursor -} - type ApplyFunc func(*Cursor) bool type Cursor struct { @@ -68,30 +62,20 @@ func (c *Cursor) Replace(newNode AST) { type replacerFunc func(newNode, parent AST) -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -} - -var abort = new(int) // singleton, to signal termination of Apply - -func Rewrite(node AST, pre, post ApplyFunc) (result AST) { - parent := &struct{ AST }{node} +func Rewrite(node AST, pre, post ApplyFunc) (AST, error) { + outer := &struct{ AST }{node} a := &application{ - pre: pre, - post: post, - cursor: Cursor{}, + pre: pre, + post: post, } - a.apply(parent.AST, node, nil) - return parent.AST -} + err := a.rewriteAST(outer, node, func(newNode, parent AST) { + outer.AST = newNode + }) -func replacePanic(msg string) func(newNode, parent AST) { - return func(newNode, parent AST) { - panic("Tried replacing a field of a value type. This is not supported. " + msg) + if err != nil { + return nil, err } + return outer.AST, nil } diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index ae9fa38f01a..020a777248c 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -173,3 +173,10 @@ func (r *NoCloneType) String() string { } type Visit func(node AST) (bool, error) + +var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") + +type application struct { + pre, post ApplyFunc + cur Cursor +} diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go new file mode 100644 index 00000000000..42a33ab4b60 --- /dev/null +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -0,0 +1,331 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asthelpergen + +import ( + "fmt" + "go/types" + + "github.com/dave/jennifer/jen" +) + +const ( + rewriteName = "rewrite" + abort = "errAbort" +) + +type rewriteGen struct { + ifaceName string +} + +var _ generator2 = (*rewriteGen)(nil) + +func (e rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + /* + func VisitAST(in AST) (bool, error) { + if in == nil { + return false, nil + } + switch a := inA.(type) { + case *SubImpl: + return VisitSubImpl(a, b) + default: + return false, nil + } + } + */ + stmts := []jen.Code{ + jen.If(jen.Id("node == nil").Block(returnNil())), + } + + var cases []jen.Code + _ = spi.findImplementations(iface, func(t types.Type) error { + if _, ok := t.Underlying().(*types.Interface); ok { + return nil + } + typeString := types.TypeString(t, noQualifier) + funcName := rewriteName + printableTypeName(t) + spi.addType(t) + caseBlock := jen.Case(jen.Id(typeString)).Block( + jen.Return(jen.Id("a").Dot(funcName).Call(jen.Id("parent, node, replacer"))), + ) + cases = append(cases, caseBlock) + return nil + }) + + cases = append(cases, + jen.Default().Block( + jen.Comment("this should never happen"), + returnNil(), + )) + + stmts = append(stmts, jen.Switch(jen.Id("node := node.(type)").Block( + cases..., + ))) + + e.rewriteFunc(t, stmts, spi) + return nil +} + +func (e rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + fields := e.rewriteAllStructFields(t, strct, spi, true) + + stmts := []jen.Code{jen.Var().Id("err").Error()} + stmts = append(stmts, executePre()) + stmts = append(stmts, fields...) + stmts = append(stmts, jen.If(jen.Id("err != nil")).Block(jen.Return(jen.Err()))) + stmts = append(stmts, executePost(len(fields) > 0)) + stmts = append(stmts, returnNil()) + + e.rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + if node == nil { return nil } + */ + stmts := []jen.Code{jen.If(jen.Id("node == nil").Block(returnNil()))} + + /* + if !pre(&cur) { + return nil + } + */ + stmts = append(stmts, executePre()) + fields := e.rewriteAllStructFields(t, strct, spi, false) + stmts = append(stmts, fields...) + stmts = append(stmts, executePost(len(fields) > 0)) + stmts = append(stmts, returnNil()) + + e.rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + */ + + stmts := []jen.Code{ + jen.Comment("ptrToBasicMethod"), + } + e.rewriteFunc(t, stmts, spi) + + return nil +} + +func (e rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + /* + if node == nil { + return nil + } + cur := Cursor{ + node: node, + parent: parent, + replacer: replacer, + } + if !pre(&cur) { + return nil + } + */ + stmts := []jen.Code{ + jen.If(jen.Id("node == nil").Block(returnNil())), + } + stmts = append(stmts, executePre()) + + haveChildren := false + if shouldAdd(slice.Elem(), spi.iface()) { + /* + for i, el := range node { + if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(LeafSlice)[i] = newNode.(*Leaf) + }, pre, post); err != nil { + return err + } + } + */ + haveChildren = true + stmts = append(stmts, + jen.For(jen.Id("i, el").Op(":=").Id("range node")). + Block(e.rewriteChild(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("i")), false))) + } + + stmts = append(stmts, executePost(haveChildren)) + stmts = append(stmts, returnNil()) + + e.rewriteFunc(t, stmts, spi) + return nil +} + +func setupCursor() []jen.Code { + return []jen.Code{ + jen.Id("a.cur.replacer = replacer"), + jen.Id("a.cur.parent = parent"), + jen.Id("a.cur.node = node"), + } +} +func executePre() jen.Code { + curStmts := setupCursor() + curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnNil())) + return jen.If(jen.Id("a.pre!= nil").Block(curStmts...)) +} + +func executePost(seenChildren bool) jen.Code { + var curStmts []jen.Code + if seenChildren { + // if we have visited children, we have to write to the cursor fields + curStmts = setupCursor() + } else { + curStmts = append(curStmts, + jen.If(jen.Id("a.pre == nil")).Block(setupCursor()...)) + } + + curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(jen.Return(jen.Id(abort)))) + + return jen.If(jen.Id("a.post != nil")).Block(curStmts...) +} + +func (e rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { + if !shouldAdd(t, spi.iface()) { + return nil + } + + stmts := []jen.Code{executePre(), executePost(false), returnNil()} + e.rewriteFunc(t, stmts, spi) + return nil +} + +func (e rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code, spi generatorSPI) { + + /* + func (a *application) rewriteNodeType(parent AST, node NodeType, replacer replacerFunc) { + */ + + typeString := types.TypeString(t, noQualifier) + funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) + code := jen.Func().Params( + jen.Id("a").Op("*").Id("application"), + ).Id(funcName).Params( + jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc", e.ifaceName, typeString)), + ).Error().Block(stmts...) + + spi.addFunc(funcName, rewrite, code) +} + +func (e rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { + /* + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") + }, pre, post); errF != nil { + return errF + } + + */ + var output []jen.Code + for i := 0; i < strct.NumFields(); i++ { + field := strct.Field(i) + if types.Implements(field.Type(), spi.iface()) { + spi.addType(field.Type()) + output = append(output, e.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail)) + continue + } + slice, isSlice := field.Type().(*types.Slice) + if isSlice && types.Implements(slice.Elem(), spi.iface()) { + spi.addType(slice.Elem()) + id := jen.Id("i") + if fail { + id = jen.Id("_") + } + output = append(output, + jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())). + Block(e.rewriteChild(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(id), fail))) + } + } + return output +} + +func failReplacer(t types.Type, f string) *jen.Statement { + typeString := types.TypeString(t, noQualifier) + return jen.Err().Op("=").Qual("vitess.io/vitess/go/vt/vterrors", "New").Call( + jen.Qual("vitess.io/vitess/go/vt/proto/vtrpc", "Code_INTERNAL"), + jen.Lit(fmt.Sprintf("[BUG] tried to replace '%s' on '%s'", f, typeString)), + ) +} + +func (e rewriteGen) rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { + /* + if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { + parent.(*RefContainer).ASTType = newNode.(AST) + }, pre, post); errF != nil { + return errF + } + + if errF := rewriteAST(node, el, func(newNode, parent AST) { + parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) + }, pre, post); errF != nil { + return errF + } + + */ + funcName := rewriteName + printableTypeName(field) + var replaceOrFail *jen.Statement + if fail { + replaceOrFail = failReplacer(t, fieldName) + } else { + replaceOrFail = jen.Id("parent"). + Assert(jen.Id(types.TypeString(t, noQualifier))). + Add(replace). + Op("="). + Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier))) + + } + funcBlock := jen.Func().Call(jen.Id("newNode, parent").Id(e.ifaceName)). + Block(replaceOrFail) + + rewriteField := jen.If( + jen.Id("errF := ").Id("a").Dot(funcName).Call( + jen.Id("node"), + param, + funcBlock), + jen.Id("errF != nil").Block(jen.Return(jen.Id("errF")))) + + return rewriteField +} + +var noQualifier = func(p *types.Package) string { + return "" +} diff --git a/go/tools/asthelpergen/rewriter_gen.go b/go/tools/asthelpergen/rewriter_gen.go deleted file mode 100644 index b7bedbe5cc8..00000000000 --- a/go/tools/asthelpergen/rewriter_gen.go +++ /dev/null @@ -1,209 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package asthelpergen - -import ( - "go/types" - - "github.com/dave/jennifer/jen" -) - -type rewriterGen struct { - cases []jen.Code - interestingType func(types.Type) bool - ifaceName string -} - -func newRewriterGen(f func(types.Type) bool, name string) *rewriterGen { - return &rewriterGen{interestingType: f, ifaceName: name} -} - -var noQualifier = func(p *types.Package) string { - return "" -} - -func (r *rewriterGen) visitStruct(t types.Type, stroct *types.Struct) error { - typeString := types.TypeString(t, noQualifier) - typeName := printableTypeName(t) - var caseStmts []jen.Code - for i := 0; i < stroct.NumFields(); i++ { - field := stroct.Field(i) - if r.interestingType(field.Type()) { - if _, ok := t.(*types.Pointer); ok { - function := r.createReplaceMethod(typeString, field) - caseStmts = append(caseStmts, caseStmtFor(field, function)) - } else { - caseStmts = append(caseStmts, casePanicStmtFor(field, typeName+" "+field.Name())) - } - } - sliceT, ok := field.Type().(*types.Slice) - if ok && r.interestingType(sliceT.Elem()) { // we have a field containing a slice of interesting elements - function := r.createReplacementMethod(t, sliceT.Elem(), jen.Dot(field.Name())) - caseStmts = append(caseStmts, caseStmtForSliceField(field, function)) - } - } - r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(caseStmts...)) - return nil -} - -func (r *rewriterGen) visitInterface(types.Type, *types.Interface) error { - return nil // rewriter doesn't deal with interfaces -} - -func (r *rewriterGen) visitSlice(t types.Type, slice *types.Slice) error { - typeString := types.TypeString(t, noQualifier) - - var stmts []jen.Code - if r.interestingType(slice.Elem()) { - function := r.createReplacementMethod(t, slice.Elem(), jen.Empty()) - stmts = append(stmts, caseStmtForSlice(function)) - } - r.cases = append(r.cases, jen.Case(jen.Id(typeString)).Block(stmts...)) - return nil -} - -func caseStmtFor(field *types.Var, expr jen.Code) *jen.Statement { - // a.apply(node, node.Field, replacerMethod) - return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), expr) -} - -func casePanicStmtFor(field *types.Var, name string) *jen.Statement { - return jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Id("replacePanic").Call(jen.Lit(name))) -} - -func caseStmtForSlice(function *jen.Statement) jen.Code { - return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Block( - jen.Id("a").Dot("apply").Call( - jen.Id("node"), - jen.Id("el"), - function, - ), - ) -} - -func caseStmtForSliceField(field *types.Var, function *jen.Statement) jen.Code { - //for x, el := range n { - return jen.For(jen.List(jen.Op("x"), jen.Id("el"))).Op(":=").Range().Id("n").Dot(field.Name()).Block( - jen.Id("a").Dot("apply").Call( - // a.apply(node, el, replaceInterfaceSlice(x)) - jen.Id("node"), - jen.Id("el"), - function, - ), - ) -} - -func (r *rewriterGen) structCase(name string, stroct *types.Struct) (jen.Code, error) { - var stmts []jen.Code - for i := 0; i < stroct.NumFields(); i++ { - field := stroct.Field(i) - if r.interestingType(field.Type()) { - stmts = append(stmts, jen.Id("a").Dot("apply").Call(jen.Id("node"), jen.Id("n").Dot(field.Name()), jen.Nil())) - } - } - return jen.Case(jen.Op("*").Id(name)).Block(stmts...), nil -} - -func (r *rewriterGen) createReplaceMethod(structType string, field *types.Var) jen.Code { - return jen.Func().Params( - jen.Id("newNode"), - jen.Id("parent").Id(r.ifaceName), - ).Block( - jen.Id("parent").Assert(jen.Id(structType)).Dot(field.Name()).Op("=").Id("newNode").Assert(jen.Id(types.TypeString(field.Type(), noQualifier))), - ) -} - -func (r *rewriterGen) createReplacementMethod(container, elem types.Type, x jen.Code) *jen.Statement { - /* - func replacer(idx int) func(AST, AST) { - return func(newnode, container AST) { - container.(InterfaceSlice)[idx] = newnode.(AST) - } - }(x) - */ - return jen.Func().Params(jen.Id("idx").Int()).Func().Params(jen.List(jen.Id(r.ifaceName), jen.Id(r.ifaceName))).Block( - jen.Return(jen.Func().Params(jen.List(jen.Id("newNode"), jen.Id("container")).Id(r.ifaceName))).Block( - jen.Id("container").Assert(jen.Id(types.TypeString(container, noQualifier))).Add(x).Index(jen.Id("idx")).Op("="). - Id("newNode").Assert(jen.Id(types.TypeString(elem, noQualifier))), - ), - ).Call(jen.Id("x")) -} - -func (r *rewriterGen) createFile(pkgName string) (string, *jen.File) { - out := jen.NewFile(pkgName) - out.HeaderComment(licenseFileHeader) - out.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") - - out.Add( - // func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { - jen.Func().Params( - jen.Id("a").Op("*").Id("application"), - ).Id("apply").Params( - jen.Id("parent"), - jen.Id("node").Id(r.ifaceName), - jen.Id("replacer").Id("replacerFunc"), - ).Block( - /* - if node == nil || isNilValue(node) { - return - } - */ - jen.If( - jen.Id("node").Op("==").Nil().Op("||"). - Id("isNilValue").Call(jen.Id("node"))).Block( - jen.Return(), - ), - /* - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - */ - jen.Id("saved").Op(":=").Id("a").Dot("cursor"), - jen.Id("a").Dot("cursor").Dot("replacer").Op("=").Id("replacer"), - jen.Id("a").Dot("cursor").Dot("node").Op("=").Id("node"), - jen.Id("a").Dot("cursor").Dot("parent").Op("=").Id("parent"), - jen.If( - jen.Id("a").Dot("pre").Op("!=").Nil().Op("&&"). - Op("!").Id("a").Dot("pre").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( - jen.Id("a").Dot("cursor").Op("=").Id("saved"), - jen.Return(), - ), - - // switch n := node.(type) { - jen.Switch(jen.Id("n").Op(":=").Id("node").Assert(jen.Id("type")).Block( - r.cases..., - )), - - /* - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - */ - jen.If( - jen.Id("a").Dot("post").Op("!=").Nil().Op("&&"). - Op("!").Id("a").Dot("post").Call(jen.Op("&").Id("a").Dot("cursor"))).Block( - jen.Id("panic").Call(jen.Id("abort")), - ), - - // a.cursor = saved - jen.Id("a").Dot("cursor").Op("=").Id("saved"), - ), - ) - - return "rewriter.go", out -} diff --git a/go/tools/asthelpergen/visit_gen.go b/go/tools/asthelpergen/visit_gen.go index 325d71aceb2..04dddcae0ec 100644 --- a/go/tools/asthelpergen/visit_gen.go +++ b/go/tools/asthelpergen/visit_gen.go @@ -181,20 +181,6 @@ func (e visitGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI return nil } -func (e visitGen) ptrToOtherMethod(t types.Type, _ *types.Pointer, spi generatorSPI) error { - if !shouldAdd(t, spi.iface()) { - return nil - } - - stmts := []jen.Code{ - jen.Comment("ptrToOtherMethod "), - } - - visitFunc(t, stmts, spi) - - return nil -} - func (e visitGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { if !shouldAdd(t, spi.iface()) { return nil diff --git a/go/vt/sqlparser/ast_helper.go b/go/vt/sqlparser/ast_helper.go index 954ecacc58c..8172a93fdb6 100644 --- a/go/vt/sqlparser/ast_helper.go +++ b/go/vt/sqlparser/ast_helper.go @@ -17,1521 +17,358 @@ limitations under the License. package sqlparser -// EqualsSQLNode does deep equals between the two objects. -func EqualsSQLNode(inA, inB SQLNode) bool { - if inA == nil && inB == nil { - return true - } - if inA == nil || inB == nil { - return false +import ( + vtrpc "vitess.io/vitess/go/vt/proto/vtrpc" + vterrors "vitess.io/vitess/go/vt/vterrors" +) + +// CloneAlterOption creates a deep clone of the input. +func CloneAlterOption(in AlterOption) AlterOption { + if in == nil { + return nil } - switch a := inA.(type) { - case AccessMode: - b, ok := inB.(AccessMode) - if !ok { - return false - } - return a == b + switch in := in.(type) { case *AddColumns: - b, ok := inB.(*AddColumns) - if !ok { - return false - } - return EqualsRefOfAddColumns(a, b) + return CloneRefOfAddColumns(in) case *AddConstraintDefinition: - b, ok := inB.(*AddConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfAddConstraintDefinition(a, b) + return CloneRefOfAddConstraintDefinition(in) case *AddIndexDefinition: - b, ok := inB.(*AddIndexDefinition) - if !ok { - return false - } - return EqualsRefOfAddIndexDefinition(a, b) + return CloneRefOfAddIndexDefinition(in) case AlgorithmValue: - b, ok := inB.(AlgorithmValue) - if !ok { - return false - } - return a == b - case *AliasedExpr: - b, ok := inB.(*AliasedExpr) - if !ok { - return false - } - return EqualsRefOfAliasedExpr(a, b) - case *AliasedTableExpr: - b, ok := inB.(*AliasedTableExpr) - if !ok { - return false - } - return EqualsRefOfAliasedTableExpr(a, b) + return in case *AlterCharset: - b, ok := inB.(*AlterCharset) - if !ok { - return false - } - return EqualsRefOfAlterCharset(a, b) + return CloneRefOfAlterCharset(in) case *AlterColumn: - b, ok := inB.(*AlterColumn) - if !ok { - return false - } - return EqualsRefOfAlterColumn(a, b) + return CloneRefOfAlterColumn(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *Force: + return CloneRefOfForce(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case TableOptions: + return CloneTableOptions(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *Validation: + return CloneRefOfValidation(in) + default: + // this should never happen + return nil + } +} + +// CloneCharacteristic creates a deep clone of the input. +func CloneCharacteristic(in Characteristic) Characteristic { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return in + case IsolationLevel: + return in + default: + // this should never happen + return nil + } +} + +// CloneColIdent creates a deep clone of the input. +func CloneColIdent(n ColIdent) ColIdent { + return *CloneRefOfColIdent(&n) +} + +// CloneColTuple creates a deep clone of the input. +func CloneColTuple(in ColTuple) ColTuple { + if in == nil { + return nil + } + switch in := in.(type) { + case ListArg: + return CloneListArg(in) + case *Subquery: + return CloneRefOfSubquery(in) + case ValTuple: + return CloneValTuple(in) + default: + // this should never happen + return nil + } +} + +// CloneCollateAndCharset creates a deep clone of the input. +func CloneCollateAndCharset(n CollateAndCharset) CollateAndCharset { + return *CloneRefOfCollateAndCharset(&n) +} + +// CloneColumnType creates a deep clone of the input. +func CloneColumnType(n ColumnType) ColumnType { + return *CloneRefOfColumnType(&n) +} + +// CloneColumns creates a deep clone of the input. +func CloneColumns(n Columns) Columns { + res := make(Columns, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneComments creates a deep clone of the input. +func CloneComments(n Comments) Comments { + res := make(Comments, 0, len(n)) + copy(res, n) + return res +} + +// CloneConstraintInfo creates a deep clone of the input. +func CloneConstraintInfo(in ConstraintInfo) ConstraintInfo { + if in == nil { + return nil + } + switch in := in.(type) { + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + default: + // this should never happen + return nil + } +} + +// CloneDBDDLStatement creates a deep clone of the input. +func CloneDBDDLStatement(in DBDDLStatement) DBDDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false - } - return EqualsRefOfAlterDatabase(a, b) - case *AlterMigration: - b, ok := inB.(*AlterMigration) - if !ok { - return false - } - return EqualsRefOfAlterMigration(a, b) + return CloneRefOfAlterDatabase(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + default: + // this should never happen + return nil + } +} + +// CloneDDLStatement creates a deep clone of the input. +func CloneDDLStatement(in DDLStatement) DDLStatement { + if in == nil { + return nil + } + switch in := in.(type) { case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false - } - return EqualsRefOfAlterTable(a, b) + return CloneRefOfAlterTable(in) case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false - } - return EqualsRefOfAlterView(a, b) - case *AlterVschema: - b, ok := inB.(*AlterVschema) - if !ok { - return false - } - return EqualsRefOfAlterVschema(a, b) - case *AndExpr: - b, ok := inB.(*AndExpr) - if !ok { - return false - } - return EqualsRefOfAndExpr(a, b) - case Argument: - b, ok := inB.(Argument) - if !ok { - return false - } - return a == b - case *AutoIncSpec: - b, ok := inB.(*AutoIncSpec) - if !ok { - return false - } - return EqualsRefOfAutoIncSpec(a, b) - case *Begin: - b, ok := inB.(*Begin) - if !ok { - return false - } - return EqualsRefOfBegin(a, b) + return CloneRefOfAlterView(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + default: + // this should never happen + return nil + } +} + +// CloneExplain creates a deep clone of the input. +func CloneExplain(in Explain) Explain { + if in == nil { + return nil + } + switch in := in.(type) { + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + default: + // this should never happen + return nil + } +} + +// CloneExpr creates a deep clone of the input. +func CloneExpr(in Expr) Expr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return in case *BinaryExpr: - b, ok := inB.(*BinaryExpr) - if !ok { - return false - } - return EqualsRefOfBinaryExpr(a, b) + return CloneRefOfBinaryExpr(in) case BoolVal: - b, ok := inB.(BoolVal) - if !ok { - return false - } - return a == b - case *CallProc: - b, ok := inB.(*CallProc) - if !ok { - return false - } - return EqualsRefOfCallProc(a, b) + return in case *CaseExpr: - b, ok := inB.(*CaseExpr) - if !ok { - return false - } - return EqualsRefOfCaseExpr(a, b) - case *ChangeColumn: - b, ok := inB.(*ChangeColumn) - if !ok { - return false - } - return EqualsRefOfChangeColumn(a, b) - case *CheckConstraintDefinition: - b, ok := inB.(*CheckConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfCheckConstraintDefinition(a, b) - case ColIdent: - b, ok := inB.(ColIdent) - if !ok { - return false - } - return EqualsColIdent(a, b) + return CloneRefOfCaseExpr(in) case *ColName: - b, ok := inB.(*ColName) - if !ok { - return false - } - return EqualsRefOfColName(a, b) + return CloneRefOfColName(in) case *CollateExpr: - b, ok := inB.(*CollateExpr) - if !ok { - return false - } - return EqualsRefOfCollateExpr(a, b) - case *ColumnDefinition: - b, ok := inB.(*ColumnDefinition) - if !ok { - return false - } - return EqualsRefOfColumnDefinition(a, b) - case *ColumnType: - b, ok := inB.(*ColumnType) - if !ok { - return false - } - return EqualsRefOfColumnType(a, b) - case Columns: - b, ok := inB.(Columns) - if !ok { - return false - } - return EqualsColumns(a, b) - case Comments: - b, ok := inB.(Comments) - if !ok { - return false - } - return EqualsComments(a, b) - case *Commit: - b, ok := inB.(*Commit) - if !ok { - return false - } - return EqualsRefOfCommit(a, b) + return CloneRefOfCollateExpr(in) case *ComparisonExpr: - b, ok := inB.(*ComparisonExpr) - if !ok { - return false - } - return EqualsRefOfComparisonExpr(a, b) - case *ConstraintDefinition: - b, ok := inB.(*ConstraintDefinition) - if !ok { - return false - } - return EqualsRefOfConstraintDefinition(a, b) + return CloneRefOfComparisonExpr(in) case *ConvertExpr: - b, ok := inB.(*ConvertExpr) - if !ok { - return false - } - return EqualsRefOfConvertExpr(a, b) - case *ConvertType: - b, ok := inB.(*ConvertType) - if !ok { - return false - } - return EqualsRefOfConvertType(a, b) + return CloneRefOfConvertExpr(in) case *ConvertUsingExpr: - b, ok := inB.(*ConvertUsingExpr) - if !ok { - return false - } - return EqualsRefOfConvertUsingExpr(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false - } - return EqualsRefOfCreateDatabase(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false - } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false - } - return EqualsRefOfCreateView(a, b) + return CloneRefOfConvertUsingExpr(in) case *CurTimeFuncExpr: - b, ok := inB.(*CurTimeFuncExpr) - if !ok { - return false - } - return EqualsRefOfCurTimeFuncExpr(a, b) + return CloneRefOfCurTimeFuncExpr(in) case *Default: - b, ok := inB.(*Default) - if !ok { - return false - } - return EqualsRefOfDefault(a, b) - case *Delete: - b, ok := inB.(*Delete) - if !ok { - return false - } - return EqualsRefOfDelete(a, b) - case *DerivedTable: - b, ok := inB.(*DerivedTable) - if !ok { - return false - } - return EqualsRefOfDerivedTable(a, b) - case *DropColumn: - b, ok := inB.(*DropColumn) - if !ok { - return false - } - return EqualsRefOfDropColumn(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false - } - return EqualsRefOfDropDatabase(a, b) - case *DropKey: - b, ok := inB.(*DropKey) - if !ok { - return false - } - return EqualsRefOfDropKey(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false - } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false - } - return EqualsRefOfDropView(a, b) + return CloneRefOfDefault(in) case *ExistsExpr: - b, ok := inB.(*ExistsExpr) - if !ok { - return false - } - return EqualsRefOfExistsExpr(a, b) - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false - } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false - } - return EqualsRefOfExplainTab(a, b) - case Exprs: - b, ok := inB.(Exprs) - if !ok { - return false - } - return EqualsExprs(a, b) - case *Flush: - b, ok := inB.(*Flush) - if !ok { - return false - } - return EqualsRefOfFlush(a, b) - case *Force: - b, ok := inB.(*Force) - if !ok { - return false - } - return EqualsRefOfForce(a, b) - case *ForeignKeyDefinition: - b, ok := inB.(*ForeignKeyDefinition) - if !ok { - return false - } - return EqualsRefOfForeignKeyDefinition(a, b) + return CloneRefOfExistsExpr(in) case *FuncExpr: - b, ok := inB.(*FuncExpr) - if !ok { - return false - } - return EqualsRefOfFuncExpr(a, b) - case GroupBy: - b, ok := inB.(GroupBy) - if !ok { - return false - } - return EqualsGroupBy(a, b) + return CloneRefOfFuncExpr(in) case *GroupConcatExpr: - b, ok := inB.(*GroupConcatExpr) - if !ok { - return false - } - return EqualsRefOfGroupConcatExpr(a, b) - case *IndexDefinition: - b, ok := inB.(*IndexDefinition) - if !ok { - return false - } - return EqualsRefOfIndexDefinition(a, b) - case *IndexHints: - b, ok := inB.(*IndexHints) - if !ok { - return false - } - return EqualsRefOfIndexHints(a, b) - case *IndexInfo: - b, ok := inB.(*IndexInfo) - if !ok { - return false - } - return EqualsRefOfIndexInfo(a, b) - case *Insert: - b, ok := inB.(*Insert) - if !ok { - return false - } - return EqualsRefOfInsert(a, b) + return CloneRefOfGroupConcatExpr(in) case *IntervalExpr: - b, ok := inB.(*IntervalExpr) - if !ok { - return false - } - return EqualsRefOfIntervalExpr(a, b) + return CloneRefOfIntervalExpr(in) case *IsExpr: - b, ok := inB.(*IsExpr) - if !ok { - return false - } - return EqualsRefOfIsExpr(a, b) - case IsolationLevel: - b, ok := inB.(IsolationLevel) - if !ok { - return false - } - return a == b - case JoinCondition: - b, ok := inB.(JoinCondition) - if !ok { - return false - } - return EqualsJoinCondition(a, b) - case *JoinTableExpr: - b, ok := inB.(*JoinTableExpr) - if !ok { - return false - } - return EqualsRefOfJoinTableExpr(a, b) - case *KeyState: - b, ok := inB.(*KeyState) - if !ok { - return false - } - return EqualsRefOfKeyState(a, b) - case *Limit: - b, ok := inB.(*Limit) - if !ok { - return false - } - return EqualsRefOfLimit(a, b) + return CloneRefOfIsExpr(in) case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false - } - return EqualsListArg(a, b) + return CloneListArg(in) case *Literal: - b, ok := inB.(*Literal) - if !ok { - return false - } - return EqualsRefOfLiteral(a, b) - case *Load: - b, ok := inB.(*Load) - if !ok { - return false - } - return EqualsRefOfLoad(a, b) - case *LockOption: - b, ok := inB.(*LockOption) - if !ok { - return false - } - return EqualsRefOfLockOption(a, b) - case *LockTables: - b, ok := inB.(*LockTables) - if !ok { - return false - } - return EqualsRefOfLockTables(a, b) + return CloneRefOfLiteral(in) case *MatchExpr: - b, ok := inB.(*MatchExpr) - if !ok { - return false - } - return EqualsRefOfMatchExpr(a, b) - case *ModifyColumn: - b, ok := inB.(*ModifyColumn) - if !ok { - return false - } - return EqualsRefOfModifyColumn(a, b) - case *Nextval: - b, ok := inB.(*Nextval) - if !ok { - return false - } - return EqualsRefOfNextval(a, b) + return CloneRefOfMatchExpr(in) case *NotExpr: - b, ok := inB.(*NotExpr) - if !ok { - return false - } - return EqualsRefOfNotExpr(a, b) + return CloneRefOfNotExpr(in) case *NullVal: - b, ok := inB.(*NullVal) - if !ok { - return false - } - return EqualsRefOfNullVal(a, b) - case OnDup: - b, ok := inB.(OnDup) - if !ok { - return false - } - return EqualsOnDup(a, b) - case *OptLike: - b, ok := inB.(*OptLike) - if !ok { - return false - } - return EqualsRefOfOptLike(a, b) + return CloneRefOfNullVal(in) case *OrExpr: - b, ok := inB.(*OrExpr) - if !ok { - return false - } - return EqualsRefOfOrExpr(a, b) - case *Order: - b, ok := inB.(*Order) - if !ok { - return false - } - return EqualsRefOfOrder(a, b) - case OrderBy: - b, ok := inB.(OrderBy) - if !ok { - return false - } - return EqualsOrderBy(a, b) - case *OrderByOption: - b, ok := inB.(*OrderByOption) - if !ok { - return false - } - return EqualsRefOfOrderByOption(a, b) - case *OtherAdmin: - b, ok := inB.(*OtherAdmin) - if !ok { - return false - } - return EqualsRefOfOtherAdmin(a, b) - case *OtherRead: - b, ok := inB.(*OtherRead) - if !ok { - return false - } - return EqualsRefOfOtherRead(a, b) - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false - } - return EqualsRefOfParenSelect(a, b) - case *ParenTableExpr: - b, ok := inB.(*ParenTableExpr) - if !ok { - return false - } - return EqualsRefOfParenTableExpr(a, b) - case *PartitionDefinition: - b, ok := inB.(*PartitionDefinition) - if !ok { - return false - } - return EqualsRefOfPartitionDefinition(a, b) - case *PartitionSpec: - b, ok := inB.(*PartitionSpec) - if !ok { - return false - } - return EqualsRefOfPartitionSpec(a, b) - case Partitions: - b, ok := inB.(Partitions) - if !ok { - return false - } - return EqualsPartitions(a, b) + return CloneRefOfOrExpr(in) case *RangeCond: - b, ok := inB.(*RangeCond) - if !ok { - return false - } - return EqualsRefOfRangeCond(a, b) - case ReferenceAction: - b, ok := inB.(ReferenceAction) - if !ok { - return false - } - return a == b - case *Release: - b, ok := inB.(*Release) - if !ok { - return false - } - return EqualsRefOfRelease(a, b) - case *RenameIndex: - b, ok := inB.(*RenameIndex) - if !ok { - return false - } - return EqualsRefOfRenameIndex(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false - } - return EqualsRefOfRenameTable(a, b) - case *RenameTableName: - b, ok := inB.(*RenameTableName) - if !ok { - return false - } - return EqualsRefOfRenameTableName(a, b) - case *RevertMigration: - b, ok := inB.(*RevertMigration) - if !ok { - return false - } - return EqualsRefOfRevertMigration(a, b) - case *Rollback: - b, ok := inB.(*Rollback) - if !ok { - return false - } - return EqualsRefOfRollback(a, b) - case *SRollback: - b, ok := inB.(*SRollback) - if !ok { - return false - } - return EqualsRefOfSRollback(a, b) - case *Savepoint: - b, ok := inB.(*Savepoint) - if !ok { - return false - } - return EqualsRefOfSavepoint(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false - } - return EqualsRefOfSelect(a, b) - case SelectExprs: - b, ok := inB.(SelectExprs) - if !ok { - return false - } - return EqualsSelectExprs(a, b) - case *SelectInto: - b, ok := inB.(*SelectInto) - if !ok { - return false - } - return EqualsRefOfSelectInto(a, b) - case *Set: - b, ok := inB.(*Set) - if !ok { - return false - } - return EqualsRefOfSet(a, b) - case *SetExpr: - b, ok := inB.(*SetExpr) - if !ok { - return false - } - return EqualsRefOfSetExpr(a, b) - case SetExprs: - b, ok := inB.(SetExprs) - if !ok { - return false - } - return EqualsSetExprs(a, b) - case *SetTransaction: - b, ok := inB.(*SetTransaction) - if !ok { - return false - } - return EqualsRefOfSetTransaction(a, b) - case *Show: - b, ok := inB.(*Show) - if !ok { - return false - } - return EqualsRefOfShow(a, b) - case *ShowBasic: - b, ok := inB.(*ShowBasic) - if !ok { - return false - } - return EqualsRefOfShowBasic(a, b) - case *ShowCreate: - b, ok := inB.(*ShowCreate) - if !ok { - return false - } - return EqualsRefOfShowCreate(a, b) - case *ShowFilter: - b, ok := inB.(*ShowFilter) - if !ok { - return false - } - return EqualsRefOfShowFilter(a, b) - case *ShowLegacy: - b, ok := inB.(*ShowLegacy) - if !ok { - return false - } - return EqualsRefOfShowLegacy(a, b) - case *StarExpr: - b, ok := inB.(*StarExpr) - if !ok { - return false - } - return EqualsRefOfStarExpr(a, b) - case *Stream: - b, ok := inB.(*Stream) - if !ok { - return false - } - return EqualsRefOfStream(a, b) + return CloneRefOfRangeCond(in) case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false - } - return EqualsRefOfSubquery(a, b) + return CloneRefOfSubquery(in) case *SubstrExpr: - b, ok := inB.(*SubstrExpr) - if !ok { - return false - } - return EqualsRefOfSubstrExpr(a, b) - case TableExprs: - b, ok := inB.(TableExprs) - if !ok { - return false - } - return EqualsTableExprs(a, b) - case TableIdent: - b, ok := inB.(TableIdent) - if !ok { - return false - } - return EqualsTableIdent(a, b) - case TableName: - b, ok := inB.(TableName) - if !ok { - return false - } - return EqualsTableName(a, b) - case TableNames: - b, ok := inB.(TableNames) - if !ok { - return false - } - return EqualsTableNames(a, b) - case TableOptions: - b, ok := inB.(TableOptions) - if !ok { - return false - } - return EqualsTableOptions(a, b) - case *TableSpec: - b, ok := inB.(*TableSpec) - if !ok { - return false - } - return EqualsRefOfTableSpec(a, b) - case *TablespaceOperation: - b, ok := inB.(*TablespaceOperation) - if !ok { - return false - } - return EqualsRefOfTablespaceOperation(a, b) + return CloneRefOfSubstrExpr(in) case *TimestampFuncExpr: - b, ok := inB.(*TimestampFuncExpr) - if !ok { - return false - } - return EqualsRefOfTimestampFuncExpr(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false - } - return EqualsRefOfTruncateTable(a, b) + return CloneRefOfTimestampFuncExpr(in) case *UnaryExpr: - b, ok := inB.(*UnaryExpr) - if !ok { - return false - } - return EqualsRefOfUnaryExpr(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false - } - return EqualsRefOfUnion(a, b) - case *UnionSelect: - b, ok := inB.(*UnionSelect) - if !ok { - return false - } - return EqualsRefOfUnionSelect(a, b) - case *UnlockTables: - b, ok := inB.(*UnlockTables) - if !ok { - return false - } - return EqualsRefOfUnlockTables(a, b) - case *Update: - b, ok := inB.(*Update) - if !ok { - return false - } - return EqualsRefOfUpdate(a, b) - case *UpdateExpr: - b, ok := inB.(*UpdateExpr) - if !ok { - return false - } - return EqualsRefOfUpdateExpr(a, b) - case UpdateExprs: - b, ok := inB.(UpdateExprs) - if !ok { - return false - } - return EqualsUpdateExprs(a, b) - case *Use: - b, ok := inB.(*Use) - if !ok { - return false - } - return EqualsRefOfUse(a, b) - case *VStream: - b, ok := inB.(*VStream) - if !ok { - return false - } - return EqualsRefOfVStream(a, b) + return CloneRefOfUnaryExpr(in) case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false - } - return EqualsValTuple(a, b) - case *Validation: - b, ok := inB.(*Validation) - if !ok { - return false - } - return EqualsRefOfValidation(a, b) - case Values: - b, ok := inB.(Values) - if !ok { - return false - } - return EqualsValues(a, b) + return CloneValTuple(in) case *ValuesFuncExpr: - b, ok := inB.(*ValuesFuncExpr) - if !ok { - return false - } - return EqualsRefOfValuesFuncExpr(a, b) - case VindexParam: - b, ok := inB.(VindexParam) - if !ok { - return false - } - return EqualsVindexParam(a, b) - case *VindexSpec: - b, ok := inB.(*VindexSpec) - if !ok { - return false - } - return EqualsRefOfVindexSpec(a, b) - case *When: - b, ok := inB.(*When) - if !ok { - return false - } - return EqualsRefOfWhen(a, b) - case *Where: - b, ok := inB.(*Where) - if !ok { - return false - } - return EqualsRefOfWhere(a, b) + return CloneRefOfValuesFuncExpr(in) case *XorExpr: - b, ok := inB.(*XorExpr) - if !ok { - return false - } - return EqualsRefOfXorExpr(a, b) + return CloneRefOfXorExpr(in) default: // this should never happen - return false + return nil } } -// CloneSQLNode creates a deep clone of the input. -func CloneSQLNode(in SQLNode) SQLNode { +// CloneExprs creates a deep clone of the input. +func CloneExprs(n Exprs) Exprs { + res := make(Exprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneGroupBy creates a deep clone of the input. +func CloneGroupBy(n GroupBy) GroupBy { + res := make(GroupBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneInsertRows creates a deep clone of the input. +func CloneInsertRows(in InsertRows) InsertRows { if in == nil { return nil } switch in := in.(type) { - case AccessMode: - return in - case *AddColumns: - return CloneRefOfAddColumns(in) - case *AddConstraintDefinition: - return CloneRefOfAddConstraintDefinition(in) - case *AddIndexDefinition: - return CloneRefOfAddIndexDefinition(in) - case AlgorithmValue: - return in - case *AliasedExpr: - return CloneRefOfAliasedExpr(in) - case *AliasedTableExpr: - return CloneRefOfAliasedTableExpr(in) - case *AlterCharset: - return CloneRefOfAlterCharset(in) - case *AlterColumn: - return CloneRefOfAlterColumn(in) - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *AlterMigration: - return CloneRefOfAlterMigration(in) - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *AlterVschema: - return CloneRefOfAlterVschema(in) - case *AndExpr: - return CloneRefOfAndExpr(in) - case Argument: - return in - case *AutoIncSpec: - return CloneRefOfAutoIncSpec(in) - case *Begin: - return CloneRefOfBegin(in) - case *BinaryExpr: - return CloneRefOfBinaryExpr(in) - case BoolVal: - return in - case *CallProc: - return CloneRefOfCallProc(in) - case *CaseExpr: - return CloneRefOfCaseExpr(in) - case *ChangeColumn: - return CloneRefOfChangeColumn(in) - case *CheckConstraintDefinition: - return CloneRefOfCheckConstraintDefinition(in) - case ColIdent: - return CloneColIdent(in) - case *ColName: - return CloneRefOfColName(in) - case *CollateExpr: - return CloneRefOfCollateExpr(in) - case *ColumnDefinition: - return CloneRefOfColumnDefinition(in) - case *ColumnType: - return CloneRefOfColumnType(in) - case Columns: - return CloneColumns(in) - case Comments: - return CloneComments(in) - case *Commit: - return CloneRefOfCommit(in) - case *ComparisonExpr: - return CloneRefOfComparisonExpr(in) - case *ConstraintDefinition: - return CloneRefOfConstraintDefinition(in) - case *ConvertExpr: - return CloneRefOfConvertExpr(in) - case *ConvertType: - return CloneRefOfConvertType(in) - case *ConvertUsingExpr: - return CloneRefOfConvertUsingExpr(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *CurTimeFuncExpr: - return CloneRefOfCurTimeFuncExpr(in) - case *Default: - return CloneRefOfDefault(in) - case *Delete: - return CloneRefOfDelete(in) - case *DerivedTable: - return CloneRefOfDerivedTable(in) - case *DropColumn: - return CloneRefOfDropColumn(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - case *DropKey: - return CloneRefOfDropKey(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *ExistsExpr: - return CloneRefOfExistsExpr(in) - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - case Exprs: - return CloneExprs(in) - case *Flush: - return CloneRefOfFlush(in) - case *Force: - return CloneRefOfForce(in) - case *ForeignKeyDefinition: - return CloneRefOfForeignKeyDefinition(in) - case *FuncExpr: - return CloneRefOfFuncExpr(in) - case GroupBy: - return CloneGroupBy(in) - case *GroupConcatExpr: - return CloneRefOfGroupConcatExpr(in) - case *IndexDefinition: - return CloneRefOfIndexDefinition(in) - case *IndexHints: - return CloneRefOfIndexHints(in) - case *IndexInfo: - return CloneRefOfIndexInfo(in) - case *Insert: - return CloneRefOfInsert(in) - case *IntervalExpr: - return CloneRefOfIntervalExpr(in) - case *IsExpr: - return CloneRefOfIsExpr(in) - case IsolationLevel: - return in - case JoinCondition: - return CloneJoinCondition(in) - case *JoinTableExpr: - return CloneRefOfJoinTableExpr(in) - case *KeyState: - return CloneRefOfKeyState(in) - case *Limit: - return CloneRefOfLimit(in) - case ListArg: - return CloneListArg(in) - case *Literal: - return CloneRefOfLiteral(in) - case *Load: - return CloneRefOfLoad(in) - case *LockOption: - return CloneRefOfLockOption(in) - case *LockTables: - return CloneRefOfLockTables(in) - case *MatchExpr: - return CloneRefOfMatchExpr(in) - case *ModifyColumn: - return CloneRefOfModifyColumn(in) - case *Nextval: - return CloneRefOfNextval(in) - case *NotExpr: - return CloneRefOfNotExpr(in) - case *NullVal: - return CloneRefOfNullVal(in) - case OnDup: - return CloneOnDup(in) - case *OptLike: - return CloneRefOfOptLike(in) - case *OrExpr: - return CloneRefOfOrExpr(in) - case *Order: - return CloneRefOfOrder(in) - case OrderBy: - return CloneOrderBy(in) - case *OrderByOption: - return CloneRefOfOrderByOption(in) - case *OtherAdmin: - return CloneRefOfOtherAdmin(in) - case *OtherRead: - return CloneRefOfOtherRead(in) case *ParenSelect: return CloneRefOfParenSelect(in) - case *ParenTableExpr: - return CloneRefOfParenTableExpr(in) - case *PartitionDefinition: - return CloneRefOfPartitionDefinition(in) - case *PartitionSpec: - return CloneRefOfPartitionSpec(in) - case Partitions: - return ClonePartitions(in) - case *RangeCond: - return CloneRefOfRangeCond(in) - case ReferenceAction: - return in - case *Release: - return CloneRefOfRelease(in) - case *RenameIndex: - return CloneRefOfRenameIndex(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *RenameTableName: - return CloneRefOfRenameTableName(in) - case *RevertMigration: - return CloneRefOfRevertMigration(in) - case *Rollback: - return CloneRefOfRollback(in) - case *SRollback: - return CloneRefOfSRollback(in) - case *Savepoint: - return CloneRefOfSavepoint(in) case *Select: return CloneRefOfSelect(in) - case SelectExprs: - return CloneSelectExprs(in) - case *SelectInto: - return CloneRefOfSelectInto(in) - case *Set: - return CloneRefOfSet(in) - case *SetExpr: - return CloneRefOfSetExpr(in) - case SetExprs: - return CloneSetExprs(in) - case *SetTransaction: - return CloneRefOfSetTransaction(in) - case *Show: - return CloneRefOfShow(in) - case *ShowBasic: - return CloneRefOfShowBasic(in) - case *ShowCreate: - return CloneRefOfShowCreate(in) - case *ShowFilter: - return CloneRefOfShowFilter(in) - case *ShowLegacy: - return CloneRefOfShowLegacy(in) - case *StarExpr: - return CloneRefOfStarExpr(in) - case *Stream: - return CloneRefOfStream(in) - case *Subquery: - return CloneRefOfSubquery(in) - case *SubstrExpr: - return CloneRefOfSubstrExpr(in) - case TableExprs: - return CloneTableExprs(in) - case TableIdent: - return CloneTableIdent(in) - case TableName: - return CloneTableName(in) - case TableNames: - return CloneTableNames(in) - case TableOptions: - return CloneTableOptions(in) - case *TableSpec: - return CloneRefOfTableSpec(in) - case *TablespaceOperation: - return CloneRefOfTablespaceOperation(in) - case *TimestampFuncExpr: - return CloneRefOfTimestampFuncExpr(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - case *UnaryExpr: - return CloneRefOfUnaryExpr(in) case *Union: return CloneRefOfUnion(in) - case *UnionSelect: - return CloneRefOfUnionSelect(in) - case *UnlockTables: - return CloneRefOfUnlockTables(in) - case *Update: - return CloneRefOfUpdate(in) - case *UpdateExpr: - return CloneRefOfUpdateExpr(in) - case UpdateExprs: - return CloneUpdateExprs(in) - case *Use: - return CloneRefOfUse(in) - case *VStream: - return CloneRefOfVStream(in) - case ValTuple: - return CloneValTuple(in) - case *Validation: - return CloneRefOfValidation(in) case Values: return CloneValues(in) - case *ValuesFuncExpr: - return CloneRefOfValuesFuncExpr(in) - case VindexParam: - return CloneVindexParam(in) - case *VindexSpec: - return CloneRefOfVindexSpec(in) - case *When: - return CloneRefOfWhen(in) - case *Where: - return CloneRefOfWhere(in) - case *XorExpr: - return CloneRefOfXorExpr(in) default: // this should never happen return nil } } -// VisitSQLNode will visit all parts of the AST -func VisitSQLNode(in SQLNode, f Visit) error { - if in == nil { - return nil - } - switch in := in.(type) { - case AccessMode: - return VisitAccessMode(in, f) - case *AddColumns: - return VisitRefOfAddColumns(in, f) - case *AddConstraintDefinition: - return VisitRefOfAddConstraintDefinition(in, f) - case *AddIndexDefinition: - return VisitRefOfAddIndexDefinition(in, f) - case AlgorithmValue: - return VisitAlgorithmValue(in, f) - case *AliasedExpr: - return VisitRefOfAliasedExpr(in, f) - case *AliasedTableExpr: - return VisitRefOfAliasedTableExpr(in, f) - case *AlterCharset: - return VisitRefOfAlterCharset(in, f) - case *AlterColumn: - return VisitRefOfAlterColumn(in, f) - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *AlterMigration: - return VisitRefOfAlterMigration(in, f) - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *AlterVschema: - return VisitRefOfAlterVschema(in, f) - case *AndExpr: - return VisitRefOfAndExpr(in, f) - case Argument: - return VisitArgument(in, f) - case *AutoIncSpec: - return VisitRefOfAutoIncSpec(in, f) - case *Begin: - return VisitRefOfBegin(in, f) - case *BinaryExpr: - return VisitRefOfBinaryExpr(in, f) - case BoolVal: - return VisitBoolVal(in, f) - case *CallProc: - return VisitRefOfCallProc(in, f) - case *CaseExpr: - return VisitRefOfCaseExpr(in, f) - case *ChangeColumn: - return VisitRefOfChangeColumn(in, f) - case *CheckConstraintDefinition: - return VisitRefOfCheckConstraintDefinition(in, f) - case ColIdent: - return VisitColIdent(in, f) - case *ColName: - return VisitRefOfColName(in, f) - case *CollateExpr: - return VisitRefOfCollateExpr(in, f) - case *ColumnDefinition: - return VisitRefOfColumnDefinition(in, f) - case *ColumnType: - return VisitRefOfColumnType(in, f) - case Columns: - return VisitColumns(in, f) - case Comments: - return VisitComments(in, f) - case *Commit: - return VisitRefOfCommit(in, f) - case *ComparisonExpr: - return VisitRefOfComparisonExpr(in, f) - case *ConstraintDefinition: - return VisitRefOfConstraintDefinition(in, f) - case *ConvertExpr: - return VisitRefOfConvertExpr(in, f) - case *ConvertType: - return VisitRefOfConvertType(in, f) - case *ConvertUsingExpr: - return VisitRefOfConvertUsingExpr(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *CurTimeFuncExpr: - return VisitRefOfCurTimeFuncExpr(in, f) - case *Default: - return VisitRefOfDefault(in, f) - case *Delete: - return VisitRefOfDelete(in, f) - case *DerivedTable: - return VisitRefOfDerivedTable(in, f) - case *DropColumn: - return VisitRefOfDropColumn(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - case *DropKey: - return VisitRefOfDropKey(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *ExistsExpr: - return VisitRefOfExistsExpr(in, f) - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - case Exprs: - return VisitExprs(in, f) - case *Flush: - return VisitRefOfFlush(in, f) - case *Force: - return VisitRefOfForce(in, f) - case *ForeignKeyDefinition: - return VisitRefOfForeignKeyDefinition(in, f) - case *FuncExpr: - return VisitRefOfFuncExpr(in, f) - case GroupBy: - return VisitGroupBy(in, f) - case *GroupConcatExpr: - return VisitRefOfGroupConcatExpr(in, f) - case *IndexDefinition: - return VisitRefOfIndexDefinition(in, f) - case *IndexHints: - return VisitRefOfIndexHints(in, f) - case *IndexInfo: - return VisitRefOfIndexInfo(in, f) - case *Insert: - return VisitRefOfInsert(in, f) - case *IntervalExpr: - return VisitRefOfIntervalExpr(in, f) - case *IsExpr: - return VisitRefOfIsExpr(in, f) - case IsolationLevel: - return VisitIsolationLevel(in, f) - case JoinCondition: - return VisitJoinCondition(in, f) - case *JoinTableExpr: - return VisitRefOfJoinTableExpr(in, f) - case *KeyState: - return VisitRefOfKeyState(in, f) - case *Limit: - return VisitRefOfLimit(in, f) - case ListArg: - return VisitListArg(in, f) - case *Literal: - return VisitRefOfLiteral(in, f) - case *Load: - return VisitRefOfLoad(in, f) - case *LockOption: - return VisitRefOfLockOption(in, f) - case *LockTables: - return VisitRefOfLockTables(in, f) - case *MatchExpr: - return VisitRefOfMatchExpr(in, f) - case *ModifyColumn: - return VisitRefOfModifyColumn(in, f) - case *Nextval: - return VisitRefOfNextval(in, f) - case *NotExpr: - return VisitRefOfNotExpr(in, f) - case *NullVal: - return VisitRefOfNullVal(in, f) - case OnDup: - return VisitOnDup(in, f) - case *OptLike: - return VisitRefOfOptLike(in, f) - case *OrExpr: - return VisitRefOfOrExpr(in, f) - case *Order: - return VisitRefOfOrder(in, f) - case OrderBy: - return VisitOrderBy(in, f) - case *OrderByOption: - return VisitRefOfOrderByOption(in, f) - case *OtherAdmin: - return VisitRefOfOtherAdmin(in, f) - case *OtherRead: - return VisitRefOfOtherRead(in, f) - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *ParenTableExpr: - return VisitRefOfParenTableExpr(in, f) - case *PartitionDefinition: - return VisitRefOfPartitionDefinition(in, f) - case *PartitionSpec: - return VisitRefOfPartitionSpec(in, f) - case Partitions: - return VisitPartitions(in, f) - case *RangeCond: - return VisitRefOfRangeCond(in, f) - case ReferenceAction: - return VisitReferenceAction(in, f) - case *Release: - return VisitRefOfRelease(in, f) - case *RenameIndex: - return VisitRefOfRenameIndex(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *RenameTableName: - return VisitRefOfRenameTableName(in, f) - case *RevertMigration: - return VisitRefOfRevertMigration(in, f) - case *Rollback: - return VisitRefOfRollback(in, f) - case *SRollback: - return VisitRefOfSRollback(in, f) - case *Savepoint: - return VisitRefOfSavepoint(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case SelectExprs: - return VisitSelectExprs(in, f) - case *SelectInto: - return VisitRefOfSelectInto(in, f) - case *Set: - return VisitRefOfSet(in, f) - case *SetExpr: - return VisitRefOfSetExpr(in, f) - case SetExprs: - return VisitSetExprs(in, f) - case *SetTransaction: - return VisitRefOfSetTransaction(in, f) - case *Show: - return VisitRefOfShow(in, f) - case *ShowBasic: - return VisitRefOfShowBasic(in, f) - case *ShowCreate: - return VisitRefOfShowCreate(in, f) - case *ShowFilter: - return VisitRefOfShowFilter(in, f) - case *ShowLegacy: - return VisitRefOfShowLegacy(in, f) - case *StarExpr: - return VisitRefOfStarExpr(in, f) - case *Stream: - return VisitRefOfStream(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case *SubstrExpr: - return VisitRefOfSubstrExpr(in, f) - case TableExprs: - return VisitTableExprs(in, f) - case TableIdent: - return VisitTableIdent(in, f) - case TableName: - return VisitTableName(in, f) - case TableNames: - return VisitTableNames(in, f) - case TableOptions: - return VisitTableOptions(in, f) - case *TableSpec: - return VisitRefOfTableSpec(in, f) - case *TablespaceOperation: - return VisitRefOfTablespaceOperation(in, f) - case *TimestampFuncExpr: - return VisitRefOfTimestampFuncExpr(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - case *UnaryExpr: - return VisitRefOfUnaryExpr(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case *UnionSelect: - return VisitRefOfUnionSelect(in, f) - case *UnlockTables: - return VisitRefOfUnlockTables(in, f) - case *Update: - return VisitRefOfUpdate(in, f) - case *UpdateExpr: - return VisitRefOfUpdateExpr(in, f) - case UpdateExprs: - return VisitUpdateExprs(in, f) - case *Use: - return VisitRefOfUse(in, f) - case *VStream: - return VisitRefOfVStream(in, f) - case ValTuple: - return VisitValTuple(in, f) - case *Validation: - return VisitRefOfValidation(in, f) - case Values: - return VisitValues(in, f) - case *ValuesFuncExpr: - return VisitRefOfValuesFuncExpr(in, f) - case VindexParam: - return VisitVindexParam(in, f) - case *VindexSpec: - return VisitRefOfVindexSpec(in, f) - case *When: - return VisitRefOfWhen(in, f) - case *Where: - return VisitRefOfWhere(in, f) - case *XorExpr: - return VisitRefOfXorExpr(in, f) - default: - // this should never happen - return nil - } -} - -// EqualsRefOfAddColumns does deep equals between the two objects. -func EqualsRefOfAddColumns(a, b *AddColumns) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) -} - -// CloneRefOfAddColumns creates a deep clone of the input. -func CloneRefOfAddColumns(n *AddColumns) *AddColumns { - if n == nil { +// CloneJoinCondition creates a deep clone of the input. +func CloneJoinCondition(n JoinCondition) JoinCondition { + return *CloneRefOfJoinCondition(&n) +} + +// CloneListArg creates a deep clone of the input. +func CloneListArg(n ListArg) ListArg { + res := make(ListArg, 0, len(n)) + copy(res, n) + return res +} + +// CloneOnDup creates a deep clone of the input. +func CloneOnDup(n OnDup) OnDup { + res := make(OnDup, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneOrderBy creates a deep clone of the input. +func CloneOrderBy(n OrderBy) OrderBy { + res := make(OrderBy, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfOrder(x)) + } + return res +} + +// ClonePartitions creates a deep clone of the input. +func ClonePartitions(n Partitions) Partitions { + res := make(Partitions, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneRefOfAddColumns creates a deep clone of the input. +func CloneRefOfAddColumns(n *AddColumns) *AddColumns { + if n == nil { return nil } out := *n @@ -1541,202 +378,150 @@ func CloneRefOfAddColumns(n *AddColumns) *AddColumns { return &out } -// VisitRefOfAddColumns will visit all parts of the AST -func VisitRefOfAddColumns(in *AddColumns, f Visit) error { - if in == nil { +// CloneRefOfAddConstraintDefinition creates a deep clone of the input. +func CloneRefOfAddConstraintDefinition(n *AddConstraintDefinition) *AddConstraintDefinition { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in.Columns { - if err := VisitRefOfColumnDefinition(el, f); err != nil { - return err - } - } - if err := VisitRefOfColName(in.First, f); err != nil { - return err - } - if err := VisitRefOfColName(in.After, f); err != nil { - return err - } - return nil + out := *n + out.ConstraintDefinition = CloneRefOfConstraintDefinition(n.ConstraintDefinition) + return &out } -// EqualsRefOfAddConstraintDefinition does deep equals between the two objects. -func EqualsRefOfAddConstraintDefinition(a, b *AddConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAddIndexDefinition creates a deep clone of the input. +func CloneRefOfAddIndexDefinition(n *AddIndexDefinition) *AddIndexDefinition { + if n == nil { + return nil } - return EqualsRefOfConstraintDefinition(a.ConstraintDefinition, b.ConstraintDefinition) + out := *n + out.IndexDefinition = CloneRefOfIndexDefinition(n.IndexDefinition) + return &out } -// CloneRefOfAddConstraintDefinition creates a deep clone of the input. -func CloneRefOfAddConstraintDefinition(n *AddConstraintDefinition) *AddConstraintDefinition { +// CloneRefOfAliasedExpr creates a deep clone of the input. +func CloneRefOfAliasedExpr(n *AliasedExpr) *AliasedExpr { if n == nil { return nil } out := *n - out.ConstraintDefinition = CloneRefOfConstraintDefinition(n.ConstraintDefinition) + out.Expr = CloneExpr(n.Expr) + out.As = CloneColIdent(n.As) return &out } -// VisitRefOfAddConstraintDefinition will visit all parts of the AST -func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) error { - if in == nil { +// CloneRefOfAliasedTableExpr creates a deep clone of the input. +func CloneRefOfAliasedTableExpr(n *AliasedTableExpr) *AliasedTableExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfConstraintDefinition(in.ConstraintDefinition, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfAddIndexDefinition does deep equals between the two objects. -func EqualsRefOfAddIndexDefinition(a, b *AddIndexDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfIndexDefinition(a.IndexDefinition, b.IndexDefinition) + out := *n + out.Expr = CloneSimpleTableExpr(n.Expr) + out.Partitions = ClonePartitions(n.Partitions) + out.As = CloneTableIdent(n.As) + out.Hints = CloneRefOfIndexHints(n.Hints) + return &out } -// CloneRefOfAddIndexDefinition creates a deep clone of the input. -func CloneRefOfAddIndexDefinition(n *AddIndexDefinition) *AddIndexDefinition { +// CloneRefOfAlterCharset creates a deep clone of the input. +func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { if n == nil { return nil } out := *n - out.IndexDefinition = CloneRefOfIndexDefinition(n.IndexDefinition) return &out } -// VisitRefOfAddIndexDefinition will visit all parts of the AST -func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { - if in == nil { +// CloneRefOfAlterColumn creates a deep clone of the input. +func CloneRefOfAlterColumn(n *AlterColumn) *AlterColumn { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfIndexDefinition(in.IndexDefinition, f); err != nil { - return err - } - return nil + out := *n + out.Column = CloneRefOfColName(n.Column) + out.DefaultVal = CloneExpr(n.DefaultVal) + return &out } -// EqualsRefOfAliasedExpr does deep equals between the two objects. -func EqualsRefOfAliasedExpr(a, b *AliasedExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAlterDatabase creates a deep clone of the input. +func CloneRefOfAlterDatabase(n *AlterDatabase) *AlterDatabase { + if n == nil { + return nil } - return EqualsExpr(a.Expr, b.Expr) && - EqualsColIdent(a.As, b.As) + out := *n + out.AlterOptions = CloneSliceOfCollateAndCharset(n.AlterOptions) + return &out } -// CloneRefOfAliasedExpr creates a deep clone of the input. -func CloneRefOfAliasedExpr(n *AliasedExpr) *AliasedExpr { +// CloneRefOfAlterMigration creates a deep clone of the input. +func CloneRefOfAlterMigration(n *AlterMigration) *AlterMigration { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) - out.As = CloneColIdent(n.As) return &out } -// VisitRefOfAliasedExpr will visit all parts of the AST -func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { - if in == nil { +// CloneRefOfAlterTable creates a deep clone of the input. +func CloneRefOfAlterTable(n *AlterTable) *AlterTable { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - if err := VisitColIdent(in.As, f); err != nil { - return err - } - return nil + out := *n + out.Table = CloneTableName(n.Table) + out.AlterOptions = CloneSliceOfAlterOption(n.AlterOptions) + out.PartitionSpec = CloneRefOfPartitionSpec(n.PartitionSpec) + return &out } -// EqualsRefOfAliasedTableExpr does deep equals between the two objects. -func EqualsRefOfAliasedTableExpr(a, b *AliasedTableExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAlterView creates a deep clone of the input. +func CloneRefOfAlterView(n *AlterView) *AlterView { + if n == nil { + return nil } - return EqualsSimpleTableExpr(a.Expr, b.Expr) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsTableIdent(a.As, b.As) && - EqualsRefOfIndexHints(a.Hints, b.Hints) + out := *n + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) + return &out } -// CloneRefOfAliasedTableExpr creates a deep clone of the input. -func CloneRefOfAliasedTableExpr(n *AliasedTableExpr) *AliasedTableExpr { +// CloneRefOfAlterVschema creates a deep clone of the input. +func CloneRefOfAlterVschema(n *AlterVschema) *AlterVschema { if n == nil { return nil } out := *n - out.Expr = CloneSimpleTableExpr(n.Expr) - out.Partitions = ClonePartitions(n.Partitions) - out.As = CloneTableIdent(n.As) - out.Hints = CloneRefOfIndexHints(n.Hints) + out.Table = CloneTableName(n.Table) + out.VindexSpec = CloneRefOfVindexSpec(n.VindexSpec) + out.VindexCols = CloneSliceOfColIdent(n.VindexCols) + out.AutoIncSpec = CloneRefOfAutoIncSpec(n.AutoIncSpec) return &out } -// VisitRefOfAliasedTableExpr will visit all parts of the AST -func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { - if in == nil { +// CloneRefOfAndExpr creates a deep clone of the input. +func CloneRefOfAndExpr(n *AndExpr) *AndExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitSimpleTableExpr(in.Expr, f); err != nil { - return err - } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err - } - if err := VisitTableIdent(in.As, f); err != nil { - return err - } - if err := VisitRefOfIndexHints(in.Hints, f); err != nil { - return err - } - return nil + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out } -// EqualsRefOfAlterCharset does deep equals between the two objects. -func EqualsRefOfAlterCharset(a, b *AlterCharset) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfAutoIncSpec creates a deep clone of the input. +func CloneRefOfAutoIncSpec(n *AutoIncSpec) *AutoIncSpec { + if n == nil { + return nil } - return a.CharacterSet == b.CharacterSet && - a.Collate == b.Collate + out := *n + out.Column = CloneColIdent(n.Column) + out.Sequence = CloneTableName(n.Sequence) + return &out } -// CloneRefOfAlterCharset creates a deep clone of the input. -func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { +// CloneRefOfBegin creates a deep clone of the input. +func CloneRefOfBegin(n *Begin) *Begin { if n == nil { return nil } @@ -1744,107 +529,88 @@ func CloneRefOfAlterCharset(n *AlterCharset) *AlterCharset { return &out } -// VisitRefOfAlterCharset will visit all parts of the AST -func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { - if in == nil { +// CloneRefOfBinaryExpr creates a deep clone of the input. +func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + return &out } -// EqualsRefOfAlterColumn does deep equals between the two objects. -func EqualsRefOfAlterColumn(a, b *AlterColumn) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfBool creates a deep clone of the input. +func CloneRefOfBool(n *bool) *bool { + if n == nil { + return nil } - return a.DropDefault == b.DropDefault && - EqualsRefOfColName(a.Column, b.Column) && - EqualsExpr(a.DefaultVal, b.DefaultVal) + out := *n + return &out } -// CloneRefOfAlterColumn creates a deep clone of the input. -func CloneRefOfAlterColumn(n *AlterColumn) *AlterColumn { +// CloneRefOfCallProc creates a deep clone of the input. +func CloneRefOfCallProc(n *CallProc) *CallProc { if n == nil { return nil } out := *n - out.Column = CloneRefOfColName(n.Column) - out.DefaultVal = CloneExpr(n.DefaultVal) + out.Name = CloneTableName(n.Name) + out.Params = CloneExprs(n.Params) return &out } -// VisitRefOfAlterColumn will visit all parts of the AST -func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { - if in == nil { +// CloneRefOfCaseExpr creates a deep clone of the input. +func CloneRefOfCaseExpr(n *CaseExpr) *CaseExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.Column, f); err != nil { - return err - } - if err := VisitExpr(in.DefaultVal, f); err != nil { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + out.Whens = CloneSliceOfRefOfWhen(n.Whens) + out.Else = CloneExpr(n.Else) + return &out } -// EqualsRefOfAlterDatabase does deep equals between the two objects. -func EqualsRefOfAlterDatabase(a, b *AlterDatabase) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfChangeColumn creates a deep clone of the input. +func CloneRefOfChangeColumn(n *ChangeColumn) *ChangeColumn { + if n == nil { + return nil } - return a.DBName == b.DBName && - a.UpdateDataDirectory == b.UpdateDataDirectory && - a.FullyParsed == b.FullyParsed && - EqualsSliceOfCollateAndCharset(a.AlterOptions, b.AlterOptions) + out := *n + out.OldColumn = CloneRefOfColName(n.OldColumn) + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out } -// CloneRefOfAlterDatabase creates a deep clone of the input. -func CloneRefOfAlterDatabase(n *AlterDatabase) *AlterDatabase { +// CloneRefOfCheckConstraintDefinition creates a deep clone of the input. +func CloneRefOfCheckConstraintDefinition(n *CheckConstraintDefinition) *CheckConstraintDefinition { if n == nil { return nil } out := *n - out.AlterOptions = CloneSliceOfCollateAndCharset(n.AlterOptions) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfAlterDatabase will visit all parts of the AST -func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { - if in == nil { +// CloneRefOfColIdent creates a deep clone of the input. +func CloneRefOfColIdent(n *ColIdent) *ColIdent { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfAlterMigration does deep equals between the two objects. -func EqualsRefOfAlterMigration(a, b *AlterMigration) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.UUID == b.UUID && - a.Type == b.Type +// CloneRefOfColName creates a deep clone of the input. +func CloneRefOfColName(n *ColName) *ColName { + return n } -// CloneRefOfAlterMigration creates a deep clone of the input. -func CloneRefOfAlterMigration(n *AlterMigration) *AlterMigration { +// CloneRefOfCollateAndCharset creates a deep clone of the input. +func CloneRefOfCollateAndCharset(n *CollateAndCharset) *CollateAndCharset { if n == nil { return nil } @@ -1852,560 +618,402 @@ func CloneRefOfAlterMigration(n *AlterMigration) *AlterMigration { return &out } -// VisitRefOfAlterMigration will visit all parts of the AST -func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { - if in == nil { +// CloneRefOfCollateExpr creates a deep clone of the input. +func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// EqualsRefOfAlterTable does deep equals between the two objects. -func EqualsRefOfAlterTable(a, b *AlterTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.FullyParsed == b.FullyParsed && - EqualsTableName(a.Table, b.Table) && - EqualsSliceOfAlterOption(a.AlterOptions, b.AlterOptions) && - EqualsRefOfPartitionSpec(a.PartitionSpec, b.PartitionSpec) + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// CloneRefOfAlterTable creates a deep clone of the input. -func CloneRefOfAlterTable(n *AlterTable) *AlterTable { +// CloneRefOfColumnDefinition creates a deep clone of the input. +func CloneRefOfColumnDefinition(n *ColumnDefinition) *ColumnDefinition { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) - out.AlterOptions = CloneSliceOfAlterOption(n.AlterOptions) - out.PartitionSpec = CloneRefOfPartitionSpec(n.PartitionSpec) + out.Name = CloneColIdent(n.Name) + out.Type = CloneColumnType(n.Type) return &out } -// VisitRefOfAlterTable will visit all parts of the AST -func VisitRefOfAlterTable(in *AlterTable, f Visit) error { - if in == nil { +// CloneRefOfColumnType creates a deep clone of the input. +func CloneRefOfColumnType(n *ColumnType) *ColumnType { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - for _, el := range in.AlterOptions { - if err := VisitAlterOption(el, f); err != nil { - return err - } - } - if err := VisitRefOfPartitionSpec(in.PartitionSpec, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfAlterView does deep equals between the two objects. -func EqualsRefOfAlterView(a, b *AlterView) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Algorithm == b.Algorithm && - a.Definer == b.Definer && - a.Security == b.Security && - a.CheckOption == b.CheckOption && - EqualsTableName(a.ViewName, b.ViewName) && - EqualsColumns(a.Columns, b.Columns) && - EqualsSelectStatement(a.Select, b.Select) + out := *n + out.Options = CloneRefOfColumnTypeOptions(n.Options) + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + out.EnumValues = CloneSliceOfString(n.EnumValues) + return &out } -// CloneRefOfAlterView creates a deep clone of the input. -func CloneRefOfAlterView(n *AlterView) *AlterView { +// CloneRefOfColumnTypeOptions creates a deep clone of the input. +func CloneRefOfColumnTypeOptions(n *ColumnTypeOptions) *ColumnTypeOptions { if n == nil { return nil } out := *n - out.ViewName = CloneTableName(n.ViewName) - out.Columns = CloneColumns(n.Columns) - out.Select = CloneSelectStatement(n.Select) + out.Default = CloneExpr(n.Default) + out.OnUpdate = CloneExpr(n.OnUpdate) + out.Comment = CloneRefOfLiteral(n.Comment) return &out } -// VisitRefOfAlterView will visit all parts of the AST -func VisitRefOfAlterView(in *AlterView, f Visit) error { - if in == nil { +// CloneRefOfCommit creates a deep clone of the input. +func CloneRefOfCommit(n *Commit) *Commit { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.ViewName, f); err != nil { - return err - } - if err := VisitColumns(in.Columns, f); err != nil { - return err - } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfAlterVschema does deep equals between the two objects. -func EqualsRefOfAlterVschema(a, b *AlterVschema) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfComparisonExpr creates a deep clone of the input. +func CloneRefOfComparisonExpr(n *ComparisonExpr) *ComparisonExpr { + if n == nil { + return nil } - return a.Action == b.Action && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfVindexSpec(a.VindexSpec, b.VindexSpec) && - EqualsSliceOfColIdent(a.VindexCols, b.VindexCols) && - EqualsRefOfAutoIncSpec(a.AutoIncSpec, b.AutoIncSpec) + out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) + out.Escape = CloneExpr(n.Escape) + return &out } -// CloneRefOfAlterVschema creates a deep clone of the input. -func CloneRefOfAlterVschema(n *AlterVschema) *AlterVschema { +// CloneRefOfConstraintDefinition creates a deep clone of the input. +func CloneRefOfConstraintDefinition(n *ConstraintDefinition) *ConstraintDefinition { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) - out.VindexSpec = CloneRefOfVindexSpec(n.VindexSpec) - out.VindexCols = CloneSliceOfColIdent(n.VindexCols) - out.AutoIncSpec = CloneRefOfAutoIncSpec(n.AutoIncSpec) + out.Details = CloneConstraintInfo(n.Details) return &out } -// VisitRefOfAlterVschema will visit all parts of the AST -func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { - if in == nil { +// CloneRefOfConvertExpr creates a deep clone of the input. +func CloneRefOfConvertExpr(n *ConvertExpr) *ConvertExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitRefOfVindexSpec(in.VindexSpec, f); err != nil { - return err - } - for _, el := range in.VindexCols { - if err := VisitColIdent(el, f); err != nil { - return err - } - } - if err := VisitRefOfAutoIncSpec(in.AutoIncSpec, f); err != nil { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + out.Type = CloneRefOfConvertType(n.Type) + return &out } -// EqualsRefOfAndExpr does deep equals between the two objects. -func EqualsRefOfAndExpr(a, b *AndExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfConvertType creates a deep clone of the input. +func CloneRefOfConvertType(n *ConvertType) *ConvertType { + if n == nil { + return nil } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) + out := *n + out.Length = CloneRefOfLiteral(n.Length) + out.Scale = CloneRefOfLiteral(n.Scale) + return &out } -// CloneRefOfAndExpr creates a deep clone of the input. -func CloneRefOfAndExpr(n *AndExpr) *AndExpr { +// CloneRefOfConvertUsingExpr creates a deep clone of the input. +func CloneRefOfConvertUsingExpr(n *ConvertUsingExpr) *ConvertUsingExpr { if n == nil { return nil } out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfAndExpr will visit all parts of the AST -func VisitRefOfAndExpr(in *AndExpr, f Visit) error { - if in == nil { +// CloneRefOfCreateDatabase creates a deep clone of the input. +func CloneRefOfCreateDatabase(n *CreateDatabase) *CreateDatabase { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - return nil + out := *n + out.Comments = CloneComments(n.Comments) + out.CreateOptions = CloneSliceOfCollateAndCharset(n.CreateOptions) + return &out } -// EqualsRefOfAutoIncSpec does deep equals between the two objects. -func EqualsRefOfAutoIncSpec(a, b *AutoIncSpec) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfCreateTable creates a deep clone of the input. +func CloneRefOfCreateTable(n *CreateTable) *CreateTable { + if n == nil { + return nil } - return EqualsColIdent(a.Column, b.Column) && - EqualsTableName(a.Sequence, b.Sequence) + out := *n + out.Table = CloneTableName(n.Table) + out.TableSpec = CloneRefOfTableSpec(n.TableSpec) + out.OptLike = CloneRefOfOptLike(n.OptLike) + return &out } -// CloneRefOfAutoIncSpec creates a deep clone of the input. -func CloneRefOfAutoIncSpec(n *AutoIncSpec) *AutoIncSpec { +// CloneRefOfCreateView creates a deep clone of the input. +func CloneRefOfCreateView(n *CreateView) *CreateView { if n == nil { return nil } out := *n - out.Column = CloneColIdent(n.Column) - out.Sequence = CloneTableName(n.Sequence) + out.ViewName = CloneTableName(n.ViewName) + out.Columns = CloneColumns(n.Columns) + out.Select = CloneSelectStatement(n.Select) return &out } -// VisitRefOfAutoIncSpec will visit all parts of the AST -func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Column, f); err != nil { - return err - } - if err := VisitTableName(in.Sequence, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfBegin does deep equals between the two objects. -func EqualsRefOfBegin(a, b *Begin) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true -} - -// CloneRefOfBegin creates a deep clone of the input. -func CloneRefOfBegin(n *Begin) *Begin { +// CloneRefOfCurTimeFuncExpr creates a deep clone of the input. +func CloneRefOfCurTimeFuncExpr(n *CurTimeFuncExpr) *CurTimeFuncExpr { if n == nil { return nil } out := *n + out.Name = CloneColIdent(n.Name) + out.Fsp = CloneExpr(n.Fsp) return &out } -// VisitRefOfBegin will visit all parts of the AST -func VisitRefOfBegin(in *Begin, f Visit) error { - if in == nil { +// CloneRefOfDefault creates a deep clone of the input. +func CloneRefOfDefault(n *Default) *Default { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfBinaryExpr does deep equals between the two objects. -func EqualsRefOfBinaryExpr(a, b *BinaryExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfDelete creates a deep clone of the input. +func CloneRefOfDelete(n *Delete) *Delete { + if n == nil { + return nil } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) + out := *n + out.Comments = CloneComments(n.Comments) + out.Targets = CloneTableNames(n.Targets) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Partitions = ClonePartitions(n.Partitions) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// CloneRefOfBinaryExpr creates a deep clone of the input. -func CloneRefOfBinaryExpr(n *BinaryExpr) *BinaryExpr { +// CloneRefOfDerivedTable creates a deep clone of the input. +func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { if n == nil { return nil } out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) + out.Select = CloneSelectStatement(n.Select) return &out } -// VisitRefOfBinaryExpr will visit all parts of the AST -func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { - if in == nil { +// CloneRefOfDropColumn creates a deep clone of the input. +func CloneRefOfDropColumn(n *DropColumn) *DropColumn { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneRefOfColName(n.Name) + return &out } -// EqualsRefOfCallProc does deep equals between the two objects. -func EqualsRefOfCallProc(a, b *CallProc) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfDropDatabase creates a deep clone of the input. +func CloneRefOfDropDatabase(n *DropDatabase) *DropDatabase { + if n == nil { + return nil } - return EqualsTableName(a.Name, b.Name) && - EqualsExprs(a.Params, b.Params) + out := *n + out.Comments = CloneComments(n.Comments) + return &out } -// CloneRefOfCallProc creates a deep clone of the input. -func CloneRefOfCallProc(n *CallProc) *CallProc { +// CloneRefOfDropKey creates a deep clone of the input. +func CloneRefOfDropKey(n *DropKey) *DropKey { if n == nil { return nil } out := *n - out.Name = CloneTableName(n.Name) - out.Params = CloneExprs(n.Params) return &out } -// VisitRefOfCallProc will visit all parts of the AST -func VisitRefOfCallProc(in *CallProc, f Visit) error { - if in == nil { +// CloneRefOfDropTable creates a deep clone of the input. +func CloneRefOfDropTable(n *DropTable) *DropTable { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Name, f); err != nil { - return err - } - if err := VisitExprs(in.Params, f); err != nil { - return err - } - return nil + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out } -// EqualsRefOfCaseExpr does deep equals between the two objects. -func EqualsRefOfCaseExpr(a, b *CaseExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfDropView creates a deep clone of the input. +func CloneRefOfDropView(n *DropView) *DropView { + if n == nil { + return nil } - return EqualsExpr(a.Expr, b.Expr) && - EqualsSliceOfRefOfWhen(a.Whens, b.Whens) && - EqualsExpr(a.Else, b.Else) + out := *n + out.FromTables = CloneTableNames(n.FromTables) + return &out } -// CloneRefOfCaseExpr creates a deep clone of the input. -func CloneRefOfCaseExpr(n *CaseExpr) *CaseExpr { +// CloneRefOfExistsExpr creates a deep clone of the input. +func CloneRefOfExistsExpr(n *ExistsExpr) *ExistsExpr { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) - out.Whens = CloneSliceOfRefOfWhen(n.Whens) - out.Else = CloneExpr(n.Else) + out.Subquery = CloneRefOfSubquery(n.Subquery) return &out } -// VisitRefOfCaseExpr will visit all parts of the AST -func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { - if in == nil { +// CloneRefOfExplainStmt creates a deep clone of the input. +func CloneRefOfExplainStmt(n *ExplainStmt) *ExplainStmt { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - for _, el := range in.Whens { - if err := VisitRefOfWhen(el, f); err != nil { - return err - } - } - if err := VisitExpr(in.Else, f); err != nil { - return err - } - return nil + out := *n + out.Statement = CloneStatement(n.Statement) + return &out } -// EqualsRefOfChangeColumn does deep equals between the two objects. -func EqualsRefOfChangeColumn(a, b *ChangeColumn) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfExplainTab creates a deep clone of the input. +func CloneRefOfExplainTab(n *ExplainTab) *ExplainTab { + if n == nil { + return nil } - return EqualsRefOfColName(a.OldColumn, b.OldColumn) && - EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) + out := *n + out.Table = CloneTableName(n.Table) + return &out } -// CloneRefOfChangeColumn creates a deep clone of the input. -func CloneRefOfChangeColumn(n *ChangeColumn) *ChangeColumn { +// CloneRefOfFlush creates a deep clone of the input. +func CloneRefOfFlush(n *Flush) *Flush { if n == nil { return nil } out := *n - out.OldColumn = CloneRefOfColName(n.OldColumn) - out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) - out.First = CloneRefOfColName(n.First) - out.After = CloneRefOfColName(n.After) + out.FlushOptions = CloneSliceOfString(n.FlushOptions) + out.TableNames = CloneTableNames(n.TableNames) return &out } -// VisitRefOfChangeColumn will visit all parts of the AST -func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { - if in == nil { +// CloneRefOfForce creates a deep clone of the input. +func CloneRefOfForce(n *Force) *Force { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.OldColumn, f); err != nil { - return err - } - if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { - return err - } - if err := VisitRefOfColName(in.First, f); err != nil { - return err - } - if err := VisitRefOfColName(in.After, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfCheckConstraintDefinition does deep equals between the two objects. -func EqualsRefOfCheckConstraintDefinition(a, b *CheckConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfForeignKeyDefinition creates a deep clone of the input. +func CloneRefOfForeignKeyDefinition(n *ForeignKeyDefinition) *ForeignKeyDefinition { + if n == nil { + return nil } - return a.Enforced == b.Enforced && - EqualsExpr(a.Expr, b.Expr) + out := *n + out.Source = CloneColumns(n.Source) + out.ReferencedTable = CloneTableName(n.ReferencedTable) + out.ReferencedColumns = CloneColumns(n.ReferencedColumns) + return &out } -// CloneRefOfCheckConstraintDefinition creates a deep clone of the input. -func CloneRefOfCheckConstraintDefinition(n *CheckConstraintDefinition) *CheckConstraintDefinition { +// CloneRefOfFuncExpr creates a deep clone of the input. +func CloneRefOfFuncExpr(n *FuncExpr) *FuncExpr { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) + out.Qualifier = CloneTableIdent(n.Qualifier) + out.Name = CloneColIdent(n.Name) + out.Exprs = CloneSelectExprs(n.Exprs) return &out } -// VisitRefOfCheckConstraintDefinition will visit all parts of the AST -func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err +// CloneRefOfGroupConcatExpr creates a deep clone of the input. +func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { + if n == nil { + return nil } - return nil -} - -// EqualsColIdent does deep equals between the two objects. -func EqualsColIdent(a, b ColIdent) bool { - return a.val == b.val && - a.lowered == b.lowered && - a.at == b.at + out := *n + out.Exprs = CloneSelectExprs(n.Exprs) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// CloneColIdent creates a deep clone of the input. -func CloneColIdent(n ColIdent) ColIdent { - return *CloneRefOfColIdent(&n) +// CloneRefOfIndexColumn creates a deep clone of the input. +func CloneRefOfIndexColumn(n *IndexColumn) *IndexColumn { + if n == nil { + return nil + } + out := *n + out.Column = CloneColIdent(n.Column) + out.Length = CloneRefOfLiteral(n.Length) + return &out } -// VisitColIdent will visit all parts of the AST -func VisitColIdent(in ColIdent, f Visit) error { - if cont, err := f(in); err != nil || !cont { - return err +// CloneRefOfIndexDefinition creates a deep clone of the input. +func CloneRefOfIndexDefinition(n *IndexDefinition) *IndexDefinition { + if n == nil { + return nil } - return nil + out := *n + out.Info = CloneRefOfIndexInfo(n.Info) + out.Columns = CloneSliceOfRefOfIndexColumn(n.Columns) + out.Options = CloneSliceOfRefOfIndexOption(n.Options) + return &out } -// EqualsRefOfColName does deep equals between the two objects. -func EqualsRefOfColName(a, b *ColName) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfIndexHints creates a deep clone of the input. +func CloneRefOfIndexHints(n *IndexHints) *IndexHints { + if n == nil { + return nil } - return EqualsColIdent(a.Name, b.Name) && - EqualsTableName(a.Qualifier, b.Qualifier) + out := *n + out.Indexes = CloneSliceOfColIdent(n.Indexes) + return &out } -// CloneRefOfColName creates a deep clone of the input. -func CloneRefOfColName(n *ColName) *ColName { - return n +// CloneRefOfIndexInfo creates a deep clone of the input. +func CloneRefOfIndexInfo(n *IndexInfo) *IndexInfo { + if n == nil { + return nil + } + out := *n + out.Name = CloneColIdent(n.Name) + out.ConstraintName = CloneColIdent(n.ConstraintName) + return &out } -// VisitRefOfColName will visit all parts of the AST -func VisitRefOfColName(in *ColName, f Visit) error { - if in == nil { +// CloneRefOfIndexOption creates a deep clone of the input. +func CloneRefOfIndexOption(n *IndexOption) *IndexOption { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitTableName(in.Qualifier, f); err != nil { - return err - } - return nil + out := *n + out.Value = CloneRefOfLiteral(n.Value) + return &out } -// EqualsRefOfCollateExpr does deep equals between the two objects. -func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfInsert creates a deep clone of the input. +func CloneRefOfInsert(n *Insert) *Insert { + if n == nil { + return nil } - return a.Charset == b.Charset && - EqualsExpr(a.Expr, b.Expr) + out := *n + out.Comments = CloneComments(n.Comments) + out.Table = CloneTableName(n.Table) + out.Partitions = ClonePartitions(n.Partitions) + out.Columns = CloneColumns(n.Columns) + out.Rows = CloneInsertRows(n.Rows) + out.OnDup = CloneOnDup(n.OnDup) + return &out } -// CloneRefOfCollateExpr creates a deep clone of the input. -func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { +// CloneRefOfIntervalExpr creates a deep clone of the input. +func CloneRefOfIntervalExpr(n *IntervalExpr) *IntervalExpr { if n == nil { return nil } @@ -2414,690 +1022,501 @@ func CloneRefOfCollateExpr(n *CollateExpr) *CollateExpr { return &out } -// VisitRefOfCollateExpr will visit all parts of the AST -func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { - if in == nil { +// CloneRefOfIsExpr creates a deep clone of the input. +func CloneRefOfIsExpr(n *IsExpr) *IsExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfColumnDefinition does deep equals between the two objects. -func EqualsRefOfColumnDefinition(a, b *ColumnDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfJoinCondition creates a deep clone of the input. +func CloneRefOfJoinCondition(n *JoinCondition) *JoinCondition { + if n == nil { + return nil } - return EqualsColIdent(a.Name, b.Name) && - EqualsColumnType(a.Type, b.Type) + out := *n + out.On = CloneExpr(n.On) + out.Using = CloneColumns(n.Using) + return &out } -// CloneRefOfColumnDefinition creates a deep clone of the input. -func CloneRefOfColumnDefinition(n *ColumnDefinition) *ColumnDefinition { +// CloneRefOfJoinTableExpr creates a deep clone of the input. +func CloneRefOfJoinTableExpr(n *JoinTableExpr) *JoinTableExpr { if n == nil { return nil } out := *n - out.Name = CloneColIdent(n.Name) - out.Type = CloneColumnType(n.Type) + out.LeftExpr = CloneTableExpr(n.LeftExpr) + out.RightExpr = CloneTableExpr(n.RightExpr) + out.Condition = CloneJoinCondition(n.Condition) return &out } -// VisitRefOfColumnDefinition will visit all parts of the AST -func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { - if in == nil { +// CloneRefOfKeyState creates a deep clone of the input. +func CloneRefOfKeyState(n *KeyState) *KeyState { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfColumnType does deep equals between the two objects. -func EqualsRefOfColumnType(a, b *ColumnType) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfLimit creates a deep clone of the input. +func CloneRefOfLimit(n *Limit) *Limit { + if n == nil { + return nil } - return a.Type == b.Type && - a.Unsigned == b.Unsigned && - a.Zerofill == b.Zerofill && - a.Charset == b.Charset && - a.Collate == b.Collate && - EqualsRefOfColumnTypeOptions(a.Options, b.Options) && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - EqualsSliceOfString(a.EnumValues, b.EnumValues) + out := *n + out.Offset = CloneExpr(n.Offset) + out.Rowcount = CloneExpr(n.Rowcount) + return &out } -// CloneRefOfColumnType creates a deep clone of the input. -func CloneRefOfColumnType(n *ColumnType) *ColumnType { +// CloneRefOfLiteral creates a deep clone of the input. +func CloneRefOfLiteral(n *Literal) *Literal { if n == nil { return nil } out := *n - out.Options = CloneRefOfColumnTypeOptions(n.Options) - out.Length = CloneRefOfLiteral(n.Length) - out.Scale = CloneRefOfLiteral(n.Scale) - out.EnumValues = CloneSliceOfString(n.EnumValues) return &out } -// VisitRefOfColumnType will visit all parts of the AST -func VisitRefOfColumnType(in *ColumnType, f Visit) error { - if in == nil { +// CloneRefOfLoad creates a deep clone of the input. +func CloneRefOfLoad(n *Load) *Load { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfLiteral(in.Length, f); err != nil { - return err - } - if err := VisitRefOfLiteral(in.Scale, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsColumns does deep equals between the two objects. -func EqualsColumns(a, b Columns) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false - } +// CloneRefOfLockOption creates a deep clone of the input. +func CloneRefOfLockOption(n *LockOption) *LockOption { + if n == nil { + return nil } - return true + out := *n + return &out } -// CloneColumns creates a deep clone of the input. -func CloneColumns(n Columns) Columns { - res := make(Columns, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) +// CloneRefOfLockTables creates a deep clone of the input. +func CloneRefOfLockTables(n *LockTables) *LockTables { + if n == nil { + return nil } - return res + out := *n + out.Tables = CloneTableAndLockTypes(n.Tables) + return &out } -// VisitColumns will visit all parts of the AST -func VisitColumns(in Columns, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitColIdent(el, f); err != nil { - return err - } +// CloneRefOfMatchExpr creates a deep clone of the input. +func CloneRefOfMatchExpr(n *MatchExpr) *MatchExpr { + if n == nil { + return nil } - return nil + out := *n + out.Columns = CloneSelectExprs(n.Columns) + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsComments does deep equals between the two objects. -func EqualsComments(a, b Comments) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } +// CloneRefOfModifyColumn creates a deep clone of the input. +func CloneRefOfModifyColumn(n *ModifyColumn) *ModifyColumn { + if n == nil { + return nil } - return true -} - -// CloneComments creates a deep clone of the input. -func CloneComments(n Comments) Comments { - res := make(Comments, 0, len(n)) - copy(res, n) - return res -} - -// VisitComments will visit all parts of the AST -func VisitComments(in Comments, f Visit) error { - _, err := f(in) - return err + out := *n + out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) + out.First = CloneRefOfColName(n.First) + out.After = CloneRefOfColName(n.After) + return &out } -// EqualsRefOfCommit does deep equals between the two objects. -func EqualsRefOfCommit(a, b *Commit) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfNextval creates a deep clone of the input. +func CloneRefOfNextval(n *Nextval) *Nextval { + if n == nil { + return nil } - return true + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// CloneRefOfCommit creates a deep clone of the input. -func CloneRefOfCommit(n *Commit) *Commit { +// CloneRefOfNotExpr creates a deep clone of the input. +func CloneRefOfNotExpr(n *NotExpr) *NotExpr { if n == nil { return nil } out := *n + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfCommit will visit all parts of the AST -func VisitRefOfCommit(in *Commit, f Visit) error { - if in == nil { +// CloneRefOfNullVal creates a deep clone of the input. +func CloneRefOfNullVal(n *NullVal) *NullVal { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfComparisonExpr does deep equals between the two objects. -func EqualsRefOfComparisonExpr(a, b *ComparisonExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfOptLike creates a deep clone of the input. +func CloneRefOfOptLike(n *OptLike) *OptLike { + if n == nil { + return nil } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) && - EqualsExpr(a.Escape, b.Escape) + out := *n + out.LikeTable = CloneTableName(n.LikeTable) + return &out } -// CloneRefOfComparisonExpr creates a deep clone of the input. -func CloneRefOfComparisonExpr(n *ComparisonExpr) *ComparisonExpr { +// CloneRefOfOrExpr creates a deep clone of the input. +func CloneRefOfOrExpr(n *OrExpr) *OrExpr { if n == nil { return nil } out := *n out.Left = CloneExpr(n.Left) out.Right = CloneExpr(n.Right) - out.Escape = CloneExpr(n.Escape) return &out } -// VisitRefOfComparisonExpr will visit all parts of the AST -func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { - if in == nil { +// CloneRefOfOrder creates a deep clone of the input. +func CloneRefOfOrder(n *Order) *Order { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err - } - if err := VisitExpr(in.Escape, f); err != nil { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfConstraintDefinition does deep equals between the two objects. -func EqualsRefOfConstraintDefinition(a, b *ConstraintDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfOrderByOption creates a deep clone of the input. +func CloneRefOfOrderByOption(n *OrderByOption) *OrderByOption { + if n == nil { + return nil } - return a.Name == b.Name && - EqualsConstraintInfo(a.Details, b.Details) + out := *n + out.Cols = CloneColumns(n.Cols) + return &out } -// CloneRefOfConstraintDefinition creates a deep clone of the input. -func CloneRefOfConstraintDefinition(n *ConstraintDefinition) *ConstraintDefinition { +// CloneRefOfOtherAdmin creates a deep clone of the input. +func CloneRefOfOtherAdmin(n *OtherAdmin) *OtherAdmin { if n == nil { return nil } out := *n - out.Details = CloneConstraintInfo(n.Details) return &out } -// VisitRefOfConstraintDefinition will visit all parts of the AST -func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { - if in == nil { +// CloneRefOfOtherRead creates a deep clone of the input. +func CloneRefOfOtherRead(n *OtherRead) *OtherRead { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitConstraintInfo(in.Details, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfConvertExpr does deep equals between the two objects. -func EqualsRefOfConvertExpr(a, b *ConvertExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfParenSelect creates a deep clone of the input. +func CloneRefOfParenSelect(n *ParenSelect) *ParenSelect { + if n == nil { + return nil } - return EqualsExpr(a.Expr, b.Expr) && - EqualsRefOfConvertType(a.Type, b.Type) + out := *n + out.Select = CloneSelectStatement(n.Select) + return &out } -// CloneRefOfConvertExpr creates a deep clone of the input. -func CloneRefOfConvertExpr(n *ConvertExpr) *ConvertExpr { +// CloneRefOfParenTableExpr creates a deep clone of the input. +func CloneRefOfParenTableExpr(n *ParenTableExpr) *ParenTableExpr { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) - out.Type = CloneRefOfConvertType(n.Type) + out.Exprs = CloneTableExprs(n.Exprs) return &out } -// VisitRefOfConvertExpr will visit all parts of the AST -func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { - if in == nil { +// CloneRefOfPartitionDefinition creates a deep clone of the input. +func CloneRefOfPartitionDefinition(n *PartitionDefinition) *PartitionDefinition { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - if err := VisitRefOfConvertType(in.Type, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Limit = CloneExpr(n.Limit) + return &out } -// EqualsRefOfConvertType does deep equals between the two objects. -func EqualsRefOfConvertType(a, b *ConvertType) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfPartitionSpec creates a deep clone of the input. +func CloneRefOfPartitionSpec(n *PartitionSpec) *PartitionSpec { + if n == nil { + return nil } - return a.Type == b.Type && - a.Charset == b.Charset && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - a.Operator == b.Operator + out := *n + out.Names = ClonePartitions(n.Names) + out.Number = CloneRefOfLiteral(n.Number) + out.TableName = CloneTableName(n.TableName) + out.Definitions = CloneSliceOfRefOfPartitionDefinition(n.Definitions) + return &out } -// CloneRefOfConvertType creates a deep clone of the input. -func CloneRefOfConvertType(n *ConvertType) *ConvertType { +// CloneRefOfRangeCond creates a deep clone of the input. +func CloneRefOfRangeCond(n *RangeCond) *RangeCond { if n == nil { return nil } out := *n - out.Length = CloneRefOfLiteral(n.Length) - out.Scale = CloneRefOfLiteral(n.Scale) + out.Left = CloneExpr(n.Left) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) return &out } -// VisitRefOfConvertType will visit all parts of the AST -func VisitRefOfConvertType(in *ConvertType, f Visit) error { - if in == nil { +// CloneRefOfRelease creates a deep clone of the input. +func CloneRefOfRelease(n *Release) *Release { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfLiteral(in.Length, f); err != nil { - return err - } - if err := VisitRefOfLiteral(in.Scale, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + return &out } -// EqualsRefOfConvertUsingExpr does deep equals between the two objects. -func EqualsRefOfConvertUsingExpr(a, b *ConvertUsingExpr) bool { - if a == b { - return true +// CloneRefOfRenameIndex creates a deep clone of the input. +func CloneRefOfRenameIndex(n *RenameIndex) *RenameIndex { + if n == nil { + return nil } - if a == nil || b == nil { - return false - } - return a.Type == b.Type && - EqualsExpr(a.Expr, b.Expr) + out := *n + return &out } -// CloneRefOfConvertUsingExpr creates a deep clone of the input. -func CloneRefOfConvertUsingExpr(n *ConvertUsingExpr) *ConvertUsingExpr { +// CloneRefOfRenameTable creates a deep clone of the input. +func CloneRefOfRenameTable(n *RenameTable) *RenameTable { if n == nil { return nil } out := *n - out.Expr = CloneExpr(n.Expr) + out.TablePairs = CloneSliceOfRefOfRenameTablePair(n.TablePairs) return &out } -// VisitRefOfConvertUsingExpr will visit all parts of the AST -func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { - if in == nil { +// CloneRefOfRenameTableName creates a deep clone of the input. +func CloneRefOfRenameTableName(n *RenameTableName) *RenameTableName { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil + out := *n + out.Table = CloneTableName(n.Table) + return &out } -// EqualsRefOfCreateDatabase does deep equals between the two objects. -func EqualsRefOfCreateDatabase(a, b *CreateDatabase) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfRenameTablePair creates a deep clone of the input. +func CloneRefOfRenameTablePair(n *RenameTablePair) *RenameTablePair { + if n == nil { + return nil } - return a.DBName == b.DBName && - a.IfNotExists == b.IfNotExists && - a.FullyParsed == b.FullyParsed && - EqualsComments(a.Comments, b.Comments) && - EqualsSliceOfCollateAndCharset(a.CreateOptions, b.CreateOptions) + out := *n + out.FromTable = CloneTableName(n.FromTable) + out.ToTable = CloneTableName(n.ToTable) + return &out } -// CloneRefOfCreateDatabase creates a deep clone of the input. -func CloneRefOfCreateDatabase(n *CreateDatabase) *CreateDatabase { +// CloneRefOfRevertMigration creates a deep clone of the input. +func CloneRefOfRevertMigration(n *RevertMigration) *RevertMigration { if n == nil { return nil } out := *n - out.Comments = CloneComments(n.Comments) - out.CreateOptions = CloneSliceOfCollateAndCharset(n.CreateOptions) return &out } -// VisitRefOfCreateDatabase will visit all parts of the AST -func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { - if in == nil { +// CloneRefOfRollback creates a deep clone of the input. +func CloneRefOfRollback(n *Rollback) *Rollback { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfCreateTable does deep equals between the two objects. -func EqualsRefOfCreateTable(a, b *CreateTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfSRollback creates a deep clone of the input. +func CloneRefOfSRollback(n *SRollback) *SRollback { + if n == nil { + return nil } - return a.Temp == b.Temp && - a.IfNotExists == b.IfNotExists && - a.FullyParsed == b.FullyParsed && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfTableSpec(a.TableSpec, b.TableSpec) && - EqualsRefOfOptLike(a.OptLike, b.OptLike) + out := *n + out.Name = CloneColIdent(n.Name) + return &out } -// CloneRefOfCreateTable creates a deep clone of the input. -func CloneRefOfCreateTable(n *CreateTable) *CreateTable { +// CloneRefOfSavepoint creates a deep clone of the input. +func CloneRefOfSavepoint(n *Savepoint) *Savepoint { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) - out.TableSpec = CloneRefOfTableSpec(n.TableSpec) - out.OptLike = CloneRefOfOptLike(n.OptLike) + out.Name = CloneColIdent(n.Name) return &out } -// VisitRefOfCreateTable will visit all parts of the AST -func VisitRefOfCreateTable(in *CreateTable, f Visit) error { - if in == nil { +// CloneRefOfSelect creates a deep clone of the input. +func CloneRefOfSelect(n *Select) *Select { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitRefOfTableSpec(in.TableSpec, f); err != nil { - return err - } - if err := VisitRefOfOptLike(in.OptLike, f); err != nil { - return err - } - return nil + out := *n + out.Cache = CloneRefOfBool(n.Cache) + out.Comments = CloneComments(n.Comments) + out.SelectExprs = CloneSelectExprs(n.SelectExprs) + out.From = CloneTableExprs(n.From) + out.Where = CloneRefOfWhere(n.Where) + out.GroupBy = CloneGroupBy(n.GroupBy) + out.Having = CloneRefOfWhere(n.Having) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + out.Into = CloneRefOfSelectInto(n.Into) + return &out } -// EqualsRefOfCreateView does deep equals between the two objects. -func EqualsRefOfCreateView(a, b *CreateView) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfSelectInto creates a deep clone of the input. +func CloneRefOfSelectInto(n *SelectInto) *SelectInto { + if n == nil { + return nil } - return a.Algorithm == b.Algorithm && - a.Definer == b.Definer && - a.Security == b.Security && - a.CheckOption == b.CheckOption && - a.IsReplace == b.IsReplace && - EqualsTableName(a.ViewName, b.ViewName) && - EqualsColumns(a.Columns, b.Columns) && - EqualsSelectStatement(a.Select, b.Select) + out := *n + return &out } -// CloneRefOfCreateView creates a deep clone of the input. -func CloneRefOfCreateView(n *CreateView) *CreateView { +// CloneRefOfSet creates a deep clone of the input. +func CloneRefOfSet(n *Set) *Set { if n == nil { return nil } out := *n - out.ViewName = CloneTableName(n.ViewName) - out.Columns = CloneColumns(n.Columns) - out.Select = CloneSelectStatement(n.Select) + out.Comments = CloneComments(n.Comments) + out.Exprs = CloneSetExprs(n.Exprs) return &out } -// VisitRefOfCreateView will visit all parts of the AST -func VisitRefOfCreateView(in *CreateView, f Visit) error { - if in == nil { +// CloneRefOfSetExpr creates a deep clone of the input. +func CloneRefOfSetExpr(n *SetExpr) *SetExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.ViewName, f); err != nil { - return err - } - if err := VisitColumns(in.Columns, f); err != nil { - return err - } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneColIdent(n.Name) + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfCurTimeFuncExpr does deep equals between the two objects. -func EqualsRefOfCurTimeFuncExpr(a, b *CurTimeFuncExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfSetTransaction creates a deep clone of the input. +func CloneRefOfSetTransaction(n *SetTransaction) *SetTransaction { + if n == nil { + return nil } - return EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Fsp, b.Fsp) + out := *n + out.SQLNode = CloneSQLNode(n.SQLNode) + out.Comments = CloneComments(n.Comments) + out.Characteristics = CloneSliceOfCharacteristic(n.Characteristics) + return &out } -// CloneRefOfCurTimeFuncExpr creates a deep clone of the input. -func CloneRefOfCurTimeFuncExpr(n *CurTimeFuncExpr) *CurTimeFuncExpr { +// CloneRefOfShow creates a deep clone of the input. +func CloneRefOfShow(n *Show) *Show { if n == nil { return nil } out := *n - out.Name = CloneColIdent(n.Name) - out.Fsp = CloneExpr(n.Fsp) + out.Internal = CloneShowInternal(n.Internal) return &out } -// VisitRefOfCurTimeFuncExpr will visit all parts of the AST -func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { - if in == nil { +// CloneRefOfShowBasic creates a deep clone of the input. +func CloneRefOfShowBasic(n *ShowBasic) *ShowBasic { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Fsp, f); err != nil { - return err - } - return nil + out := *n + out.Tbl = CloneTableName(n.Tbl) + out.Filter = CloneRefOfShowFilter(n.Filter) + return &out } -// EqualsRefOfDefault does deep equals between the two objects. -func EqualsRefOfDefault(a, b *Default) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfShowCreate creates a deep clone of the input. +func CloneRefOfShowCreate(n *ShowCreate) *ShowCreate { + if n == nil { + return nil } - return a.ColName == b.ColName + out := *n + out.Op = CloneTableName(n.Op) + return &out } -// CloneRefOfDefault creates a deep clone of the input. -func CloneRefOfDefault(n *Default) *Default { +// CloneRefOfShowFilter creates a deep clone of the input. +func CloneRefOfShowFilter(n *ShowFilter) *ShowFilter { if n == nil { return nil } out := *n + out.Filter = CloneExpr(n.Filter) return &out } -// VisitRefOfDefault will visit all parts of the AST -func VisitRefOfDefault(in *Default, f Visit) error { - if in == nil { +// CloneRefOfShowLegacy creates a deep clone of the input. +func CloneRefOfShowLegacy(n *ShowLegacy) *ShowLegacy { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// EqualsRefOfDelete does deep equals between the two objects. -func EqualsRefOfDelete(a, b *Delete) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Ignore == b.Ignore && - EqualsComments(a.Comments, b.Comments) && - EqualsTableNames(a.Targets, b.Targets) && - EqualsTableExprs(a.TableExprs, b.TableExprs) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) + out := *n + out.OnTable = CloneTableName(n.OnTable) + out.Table = CloneTableName(n.Table) + out.ShowTablesOpt = CloneRefOfShowTablesOpt(n.ShowTablesOpt) + out.ShowCollationFilterOpt = CloneExpr(n.ShowCollationFilterOpt) + return &out } -// CloneRefOfDelete creates a deep clone of the input. -func CloneRefOfDelete(n *Delete) *Delete { +// CloneRefOfShowTablesOpt creates a deep clone of the input. +func CloneRefOfShowTablesOpt(n *ShowTablesOpt) *ShowTablesOpt { if n == nil { return nil } out := *n - out.Comments = CloneComments(n.Comments) - out.Targets = CloneTableNames(n.Targets) - out.TableExprs = CloneTableExprs(n.TableExprs) - out.Partitions = ClonePartitions(n.Partitions) - out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) + out.Filter = CloneRefOfShowFilter(n.Filter) return &out } -// VisitRefOfDelete will visit all parts of the AST -func VisitRefOfDelete(in *Delete, f Visit) error { - if in == nil { +// CloneRefOfStarExpr creates a deep clone of the input. +func CloneRefOfStarExpr(n *StarExpr) *StarExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitTableNames(in.Targets, f); err != nil { - return err - } - if err := VisitTableExprs(in.TableExprs, f); err != nil { - return err - } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err - } - if err := VisitRefOfWhere(in.Where, f); err != nil { - return err - } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err - } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err - } - return nil + out := *n + out.TableName = CloneTableName(n.TableName) + return &out } -// EqualsRefOfDerivedTable does deep equals between the two objects. -func EqualsRefOfDerivedTable(a, b *DerivedTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfStream creates a deep clone of the input. +func CloneRefOfStream(n *Stream) *Stream { + if n == nil { + return nil } - return EqualsSelectStatement(a.Select, b.Select) + out := *n + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) + return &out } -// CloneRefOfDerivedTable creates a deep clone of the input. -func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { +// CloneRefOfSubquery creates a deep clone of the input. +func CloneRefOfSubquery(n *Subquery) *Subquery { if n == nil { return nil } @@ -3106,865 +1525,5350 @@ func CloneRefOfDerivedTable(n *DerivedTable) *DerivedTable { return &out } -// VisitRefOfDerivedTable will visit all parts of the AST -func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { - if in == nil { +// CloneRefOfSubstrExpr creates a deep clone of the input. +func CloneRefOfSubstrExpr(n *SubstrExpr) *SubstrExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitSelectStatement(in.Select, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneRefOfColName(n.Name) + out.StrVal = CloneRefOfLiteral(n.StrVal) + out.From = CloneExpr(n.From) + out.To = CloneExpr(n.To) + return &out } -// EqualsRefOfDropColumn does deep equals between the two objects. -func EqualsRefOfDropColumn(a, b *DropColumn) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfTableAndLockType creates a deep clone of the input. +func CloneRefOfTableAndLockType(n *TableAndLockType) *TableAndLockType { + if n == nil { + return nil } - return EqualsRefOfColName(a.Name, b.Name) + out := *n + out.Table = CloneTableExpr(n.Table) + return &out } -// CloneRefOfDropColumn creates a deep clone of the input. -func CloneRefOfDropColumn(n *DropColumn) *DropColumn { +// CloneRefOfTableIdent creates a deep clone of the input. +func CloneRefOfTableIdent(n *TableIdent) *TableIdent { if n == nil { return nil } out := *n - out.Name = CloneRefOfColName(n.Name) return &out } -// VisitRefOfDropColumn will visit all parts of the AST -func VisitRefOfDropColumn(in *DropColumn, f Visit) error { - if in == nil { +// CloneRefOfTableName creates a deep clone of the input. +func CloneRefOfTableName(n *TableName) *TableName { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err - } - return nil + out := *n + out.Name = CloneTableIdent(n.Name) + out.Qualifier = CloneTableIdent(n.Qualifier) + return &out } -// EqualsRefOfDropDatabase does deep equals between the two objects. -func EqualsRefOfDropDatabase(a, b *DropDatabase) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfTableOption creates a deep clone of the input. +func CloneRefOfTableOption(n *TableOption) *TableOption { + if n == nil { + return nil } - return a.DBName == b.DBName && - a.IfExists == b.IfExists && - EqualsComments(a.Comments, b.Comments) + out := *n + out.Value = CloneRefOfLiteral(n.Value) + out.Tables = CloneTableNames(n.Tables) + return &out } -// CloneRefOfDropDatabase creates a deep clone of the input. -func CloneRefOfDropDatabase(n *DropDatabase) *DropDatabase { +// CloneRefOfTableSpec creates a deep clone of the input. +func CloneRefOfTableSpec(n *TableSpec) *TableSpec { if n == nil { return nil } out := *n - out.Comments = CloneComments(n.Comments) + out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) + out.Indexes = CloneSliceOfRefOfIndexDefinition(n.Indexes) + out.Constraints = CloneSliceOfRefOfConstraintDefinition(n.Constraints) + out.Options = CloneTableOptions(n.Options) return &out } -// VisitRefOfDropDatabase will visit all parts of the AST -func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { - if in == nil { +// CloneRefOfTablespaceOperation creates a deep clone of the input. +func CloneRefOfTablespaceOperation(n *TablespaceOperation) *TablespaceOperation { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfDropKey does deep equals between the two objects. -func EqualsRefOfDropKey(a, b *DropKey) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfTimestampFuncExpr creates a deep clone of the input. +func CloneRefOfTimestampFuncExpr(n *TimestampFuncExpr) *TimestampFuncExpr { + if n == nil { + return nil } - return a.Name == b.Name && - a.Type == b.Type + out := *n + out.Expr1 = CloneExpr(n.Expr1) + out.Expr2 = CloneExpr(n.Expr2) + return &out } -// CloneRefOfDropKey creates a deep clone of the input. -func CloneRefOfDropKey(n *DropKey) *DropKey { +// CloneRefOfTruncateTable creates a deep clone of the input. +func CloneRefOfTruncateTable(n *TruncateTable) *TruncateTable { if n == nil { return nil } out := *n + out.Table = CloneTableName(n.Table) return &out } -// VisitRefOfDropKey will visit all parts of the AST -func VisitRefOfDropKey(in *DropKey, f Visit) error { - if in == nil { +// CloneRefOfUnaryExpr creates a deep clone of the input. +func CloneRefOfUnaryExpr(n *UnaryExpr) *UnaryExpr { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// EqualsRefOfDropTable does deep equals between the two objects. -func EqualsRefOfDropTable(a, b *DropTable) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +// CloneRefOfUnion creates a deep clone of the input. +func CloneRefOfUnion(n *Union) *Union { + if n == nil { + return nil } - return a.Temp == b.Temp && - a.IfExists == b.IfExists && - EqualsTableNames(a.FromTables, b.FromTables) + out := *n + out.FirstStatement = CloneSelectStatement(n.FirstStatement) + out.UnionSelects = CloneSliceOfRefOfUnionSelect(n.UnionSelects) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) + return &out } -// CloneRefOfDropTable creates a deep clone of the input. -func CloneRefOfDropTable(n *DropTable) *DropTable { +// CloneRefOfUnionSelect creates a deep clone of the input. +func CloneRefOfUnionSelect(n *UnionSelect) *UnionSelect { if n == nil { return nil } out := *n - out.FromTables = CloneTableNames(n.FromTables) + out.Statement = CloneSelectStatement(n.Statement) return &out } -// VisitRefOfDropTable will visit all parts of the AST -func VisitRefOfDropTable(in *DropTable, f Visit) error { - if in == nil { +// CloneRefOfUnlockTables creates a deep clone of the input. +func CloneRefOfUnlockTables(n *UnlockTables) *UnlockTables { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableNames(in.FromTables, f); err != nil { - return err - } - return nil + out := *n + return &out } -// EqualsRefOfDropView does deep equals between the two objects. -func EqualsRefOfDropView(a, b *DropView) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.IfExists == b.IfExists && - EqualsTableNames(a.FromTables, b.FromTables) -} - -// CloneRefOfDropView creates a deep clone of the input. -func CloneRefOfDropView(n *DropView) *DropView { +// CloneRefOfUpdate creates a deep clone of the input. +func CloneRefOfUpdate(n *Update) *Update { if n == nil { return nil } out := *n - out.FromTables = CloneTableNames(n.FromTables) + out.Comments = CloneComments(n.Comments) + out.TableExprs = CloneTableExprs(n.TableExprs) + out.Exprs = CloneUpdateExprs(n.Exprs) + out.Where = CloneRefOfWhere(n.Where) + out.OrderBy = CloneOrderBy(n.OrderBy) + out.Limit = CloneRefOfLimit(n.Limit) return &out } -// VisitRefOfDropView will visit all parts of the AST -func VisitRefOfDropView(in *DropView, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableNames(in.FromTables, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfExistsExpr does deep equals between the two objects. -func EqualsRefOfExistsExpr(a, b *ExistsExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfSubquery(a.Subquery, b.Subquery) -} - -// CloneRefOfExistsExpr creates a deep clone of the input. -func CloneRefOfExistsExpr(n *ExistsExpr) *ExistsExpr { +// CloneRefOfUpdateExpr creates a deep clone of the input. +func CloneRefOfUpdateExpr(n *UpdateExpr) *UpdateExpr { if n == nil { return nil } out := *n - out.Subquery = CloneRefOfSubquery(n.Subquery) + out.Name = CloneRefOfColName(n.Name) + out.Expr = CloneExpr(n.Expr) return &out } -// VisitRefOfExistsExpr will visit all parts of the AST -func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { - if in == nil { +// CloneRefOfUse creates a deep clone of the input. +func CloneRefOfUse(n *Use) *Use { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfSubquery(in.Subquery, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfExplainStmt does deep equals between the two objects. -func EqualsRefOfExplainStmt(a, b *ExplainStmt) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Type == b.Type && - EqualsStatement(a.Statement, b.Statement) + out := *n + out.DBName = CloneTableIdent(n.DBName) + return &out } -// CloneRefOfExplainStmt creates a deep clone of the input. -func CloneRefOfExplainStmt(n *ExplainStmt) *ExplainStmt { +// CloneRefOfVStream creates a deep clone of the input. +func CloneRefOfVStream(n *VStream) *VStream { if n == nil { return nil } out := *n - out.Statement = CloneStatement(n.Statement) + out.Comments = CloneComments(n.Comments) + out.SelectExpr = CloneSelectExpr(n.SelectExpr) + out.Table = CloneTableName(n.Table) + out.Where = CloneRefOfWhere(n.Where) + out.Limit = CloneRefOfLimit(n.Limit) return &out } -// VisitRefOfExplainStmt will visit all parts of the AST -func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { - if in == nil { +// CloneRefOfValidation creates a deep clone of the input. +func CloneRefOfValidation(n *Validation) *Validation { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitStatement(in.Statement, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfExplainTab does deep equals between the two objects. -func EqualsRefOfExplainTab(a, b *ExplainTab) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Wild == b.Wild && - EqualsTableName(a.Table, b.Table) + out := *n + return &out } -// CloneRefOfExplainTab creates a deep clone of the input. -func CloneRefOfExplainTab(n *ExplainTab) *ExplainTab { +// CloneRefOfValuesFuncExpr creates a deep clone of the input. +func CloneRefOfValuesFuncExpr(n *ValuesFuncExpr) *ValuesFuncExpr { if n == nil { return nil } out := *n - out.Table = CloneTableName(n.Table) + out.Name = CloneRefOfColName(n.Name) return &out } -// VisitRefOfExplainTab will visit all parts of the AST -func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { - if in == nil { +// CloneRefOfVindexParam creates a deep clone of the input. +func CloneRefOfVindexParam(n *VindexParam) *VindexParam { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - return nil -} - -// EqualsExprs does deep equals between the two objects. -func EqualsExprs(a, b Exprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneExprs creates a deep clone of the input. -func CloneExprs(n Exprs) Exprs { - res := make(Exprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) - } - return res + out := *n + out.Key = CloneColIdent(n.Key) + return &out } -// VisitExprs will visit all parts of the AST -func VisitExprs(in Exprs, f Visit) error { - if in == nil { +// CloneRefOfVindexSpec creates a deep clone of the input. +func CloneRefOfVindexSpec(n *VindexSpec) *VindexSpec { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err - } - } - return nil -} - -// EqualsRefOfFlush does deep equals between the two objects. -func EqualsRefOfFlush(a, b *Flush) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.IsLocal == b.IsLocal && - a.WithLock == b.WithLock && - a.ForExport == b.ForExport && - EqualsSliceOfString(a.FlushOptions, b.FlushOptions) && - EqualsTableNames(a.TableNames, b.TableNames) + out := *n + out.Name = CloneColIdent(n.Name) + out.Type = CloneColIdent(n.Type) + out.Params = CloneSliceOfVindexParam(n.Params) + return &out } -// CloneRefOfFlush creates a deep clone of the input. -func CloneRefOfFlush(n *Flush) *Flush { +// CloneRefOfWhen creates a deep clone of the input. +func CloneRefOfWhen(n *When) *When { if n == nil { return nil } out := *n - out.FlushOptions = CloneSliceOfString(n.FlushOptions) - out.TableNames = CloneTableNames(n.TableNames) + out.Cond = CloneExpr(n.Cond) + out.Val = CloneExpr(n.Val) return &out } -// VisitRefOfFlush will visit all parts of the AST -func VisitRefOfFlush(in *Flush, f Visit) error { - if in == nil { +// CloneRefOfWhere creates a deep clone of the input. +func CloneRefOfWhere(n *Where) *Where { + if n == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableNames(in.TableNames, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfForce does deep equals between the two objects. -func EqualsRefOfForce(a, b *Force) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true + out := *n + out.Expr = CloneExpr(n.Expr) + return &out } -// CloneRefOfForce creates a deep clone of the input. -func CloneRefOfForce(n *Force) *Force { +// CloneRefOfXorExpr creates a deep clone of the input. +func CloneRefOfXorExpr(n *XorExpr) *XorExpr { if n == nil { return nil } out := *n + out.Left = CloneExpr(n.Left) + out.Right = CloneExpr(n.Right) return &out } -// VisitRefOfForce will visit all parts of the AST -func VisitRefOfForce(in *Force, f Visit) error { +// CloneSQLNode creates a deep clone of the input. +func CloneSQLNode(in SQLNode) SQLNode { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// EqualsRefOfForeignKeyDefinition does deep equals between the two objects. -func EqualsRefOfForeignKeyDefinition(a, b *ForeignKeyDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsColumns(a.Source, b.Source) && - EqualsTableName(a.ReferencedTable, b.ReferencedTable) && - EqualsColumns(a.ReferencedColumns, b.ReferencedColumns) && - a.OnDelete == b.OnDelete && - a.OnUpdate == b.OnUpdate -} - -// CloneRefOfForeignKeyDefinition creates a deep clone of the input. -func CloneRefOfForeignKeyDefinition(n *ForeignKeyDefinition) *ForeignKeyDefinition { - if n == nil { + switch in := in.(type) { + case AccessMode: + return in + case *AddColumns: + return CloneRefOfAddColumns(in) + case *AddConstraintDefinition: + return CloneRefOfAddConstraintDefinition(in) + case *AddIndexDefinition: + return CloneRefOfAddIndexDefinition(in) + case AlgorithmValue: + return in + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *AlterCharset: + return CloneRefOfAlterCharset(in) + case *AlterColumn: + return CloneRefOfAlterColumn(in) + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterMigration: + return CloneRefOfAlterMigration(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *AndExpr: + return CloneRefOfAndExpr(in) + case Argument: + return in + case *AutoIncSpec: + return CloneRefOfAutoIncSpec(in) + case *Begin: + return CloneRefOfBegin(in) + case *BinaryExpr: + return CloneRefOfBinaryExpr(in) + case BoolVal: + return in + case *CallProc: + return CloneRefOfCallProc(in) + case *CaseExpr: + return CloneRefOfCaseExpr(in) + case *ChangeColumn: + return CloneRefOfChangeColumn(in) + case *CheckConstraintDefinition: + return CloneRefOfCheckConstraintDefinition(in) + case ColIdent: + return CloneColIdent(in) + case *ColName: + return CloneRefOfColName(in) + case *CollateExpr: + return CloneRefOfCollateExpr(in) + case *ColumnDefinition: + return CloneRefOfColumnDefinition(in) + case *ColumnType: + return CloneRefOfColumnType(in) + case Columns: + return CloneColumns(in) + case Comments: + return CloneComments(in) + case *Commit: + return CloneRefOfCommit(in) + case *ComparisonExpr: + return CloneRefOfComparisonExpr(in) + case *ConstraintDefinition: + return CloneRefOfConstraintDefinition(in) + case *ConvertExpr: + return CloneRefOfConvertExpr(in) + case *ConvertType: + return CloneRefOfConvertType(in) + case *ConvertUsingExpr: + return CloneRefOfConvertUsingExpr(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *CurTimeFuncExpr: + return CloneRefOfCurTimeFuncExpr(in) + case *Default: + return CloneRefOfDefault(in) + case *Delete: + return CloneRefOfDelete(in) + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case *DropColumn: + return CloneRefOfDropColumn(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropKey: + return CloneRefOfDropKey(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExistsExpr: + return CloneRefOfExistsExpr(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case Exprs: + return CloneExprs(in) + case *Flush: + return CloneRefOfFlush(in) + case *Force: + return CloneRefOfForce(in) + case *ForeignKeyDefinition: + return CloneRefOfForeignKeyDefinition(in) + case *FuncExpr: + return CloneRefOfFuncExpr(in) + case GroupBy: + return CloneGroupBy(in) + case *GroupConcatExpr: + return CloneRefOfGroupConcatExpr(in) + case *IndexDefinition: + return CloneRefOfIndexDefinition(in) + case *IndexHints: + return CloneRefOfIndexHints(in) + case *IndexInfo: + return CloneRefOfIndexInfo(in) + case *Insert: + return CloneRefOfInsert(in) + case *IntervalExpr: + return CloneRefOfIntervalExpr(in) + case *IsExpr: + return CloneRefOfIsExpr(in) + case IsolationLevel: + return in + case JoinCondition: + return CloneJoinCondition(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *KeyState: + return CloneRefOfKeyState(in) + case *Limit: + return CloneRefOfLimit(in) + case ListArg: + return CloneListArg(in) + case *Literal: + return CloneRefOfLiteral(in) + case *Load: + return CloneRefOfLoad(in) + case *LockOption: + return CloneRefOfLockOption(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *MatchExpr: + return CloneRefOfMatchExpr(in) + case *ModifyColumn: + return CloneRefOfModifyColumn(in) + case *Nextval: + return CloneRefOfNextval(in) + case *NotExpr: + return CloneRefOfNotExpr(in) + case *NullVal: + return CloneRefOfNullVal(in) + case OnDup: + return CloneOnDup(in) + case *OptLike: + return CloneRefOfOptLike(in) + case *OrExpr: + return CloneRefOfOrExpr(in) + case *Order: + return CloneRefOfOrder(in) + case OrderBy: + return CloneOrderBy(in) + case *OrderByOption: + return CloneRefOfOrderByOption(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + case *PartitionDefinition: + return CloneRefOfPartitionDefinition(in) + case *PartitionSpec: + return CloneRefOfPartitionSpec(in) + case Partitions: + return ClonePartitions(in) + case *RangeCond: + return CloneRefOfRangeCond(in) + case ReferenceAction: + return in + case *Release: + return CloneRefOfRelease(in) + case *RenameIndex: + return CloneRefOfRenameIndex(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *RenameTableName: + return CloneRefOfRenameTableName(in) + case *RevertMigration: + return CloneRefOfRevertMigration(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case SelectExprs: + return CloneSelectExprs(in) + case *SelectInto: + return CloneRefOfSelectInto(in) + case *Set: + return CloneRefOfSet(in) + case *SetExpr: + return CloneRefOfSetExpr(in) + case SetExprs: + return CloneSetExprs(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowFilter: + return CloneRefOfShowFilter(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + case *Stream: + return CloneRefOfStream(in) + case *Subquery: + return CloneRefOfSubquery(in) + case *SubstrExpr: + return CloneRefOfSubstrExpr(in) + case TableExprs: + return CloneTableExprs(in) + case TableIdent: + return CloneTableIdent(in) + case TableName: + return CloneTableName(in) + case TableNames: + return CloneTableNames(in) + case TableOptions: + return CloneTableOptions(in) + case *TableSpec: + return CloneRefOfTableSpec(in) + case *TablespaceOperation: + return CloneRefOfTablespaceOperation(in) + case *TimestampFuncExpr: + return CloneRefOfTimestampFuncExpr(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *UnaryExpr: + return CloneRefOfUnaryExpr(in) + case *Union: + return CloneRefOfUnion(in) + case *UnionSelect: + return CloneRefOfUnionSelect(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *UpdateExpr: + return CloneRefOfUpdateExpr(in) + case UpdateExprs: + return CloneUpdateExprs(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + case ValTuple: + return CloneValTuple(in) + case *Validation: + return CloneRefOfValidation(in) + case Values: + return CloneValues(in) + case *ValuesFuncExpr: + return CloneRefOfValuesFuncExpr(in) + case VindexParam: + return CloneVindexParam(in) + case *VindexSpec: + return CloneRefOfVindexSpec(in) + case *When: + return CloneRefOfWhen(in) + case *Where: + return CloneRefOfWhere(in) + case *XorExpr: + return CloneRefOfXorExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectExpr creates a deep clone of the input. +func CloneSelectExpr(in SelectExpr) SelectExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedExpr: + return CloneRefOfAliasedExpr(in) + case *Nextval: + return CloneRefOfNextval(in) + case *StarExpr: + return CloneRefOfStarExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneSelectExprs creates a deep clone of the input. +func CloneSelectExprs(n SelectExprs) SelectExprs { + res := make(SelectExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneSelectExpr(x)) + } + return res +} + +// CloneSelectStatement creates a deep clone of the input. +func CloneSelectStatement(in SelectStatement) SelectStatement { + if in == nil { + return nil + } + switch in := in.(type) { + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Select: + return CloneRefOfSelect(in) + case *Union: + return CloneRefOfUnion(in) + default: + // this should never happen + return nil + } +} + +// CloneSetExprs creates a deep clone of the input. +func CloneSetExprs(n SetExprs) SetExprs { + res := make(SetExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfSetExpr(x)) + } + return res +} + +// CloneShowInternal creates a deep clone of the input. +func CloneShowInternal(in ShowInternal) ShowInternal { + if in == nil { + return nil + } + switch in := in.(type) { + case *ShowBasic: + return CloneRefOfShowBasic(in) + case *ShowCreate: + return CloneRefOfShowCreate(in) + case *ShowLegacy: + return CloneRefOfShowLegacy(in) + default: + // this should never happen + return nil + } +} + +// CloneSimpleTableExpr creates a deep clone of the input. +func CloneSimpleTableExpr(in SimpleTableExpr) SimpleTableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *DerivedTable: + return CloneRefOfDerivedTable(in) + case TableName: + return CloneTableName(in) + default: + // this should never happen + return nil + } +} + +// CloneSliceOfAlterOption creates a deep clone of the input. +func CloneSliceOfAlterOption(n []AlterOption) []AlterOption { + res := make([]AlterOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneAlterOption(x)) + } + return res +} + +// CloneSliceOfCharacteristic creates a deep clone of the input. +func CloneSliceOfCharacteristic(n []Characteristic) []Characteristic { + res := make([]Characteristic, 0, len(n)) + for _, x := range n { + res = append(res, CloneCharacteristic(x)) + } + return res +} + +// CloneSliceOfColIdent creates a deep clone of the input. +func CloneSliceOfColIdent(n []ColIdent) []ColIdent { + res := make([]ColIdent, 0, len(n)) + for _, x := range n { + res = append(res, CloneColIdent(x)) + } + return res +} + +// CloneSliceOfCollateAndCharset creates a deep clone of the input. +func CloneSliceOfCollateAndCharset(n []CollateAndCharset) []CollateAndCharset { + res := make([]CollateAndCharset, 0, len(n)) + for _, x := range n { + res = append(res, CloneCollateAndCharset(x)) + } + return res +} + +// CloneSliceOfRefOfColumnDefinition creates a deep clone of the input. +func CloneSliceOfRefOfColumnDefinition(n []*ColumnDefinition) []*ColumnDefinition { + res := make([]*ColumnDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfColumnDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfConstraintDefinition creates a deep clone of the input. +func CloneSliceOfRefOfConstraintDefinition(n []*ConstraintDefinition) []*ConstraintDefinition { + res := make([]*ConstraintDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfConstraintDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfIndexColumn creates a deep clone of the input. +func CloneSliceOfRefOfIndexColumn(n []*IndexColumn) []*IndexColumn { + res := make([]*IndexColumn, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexColumn(x)) + } + return res +} + +// CloneSliceOfRefOfIndexDefinition creates a deep clone of the input. +func CloneSliceOfRefOfIndexDefinition(n []*IndexDefinition) []*IndexDefinition { + res := make([]*IndexDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfIndexOption creates a deep clone of the input. +func CloneSliceOfRefOfIndexOption(n []*IndexOption) []*IndexOption { + res := make([]*IndexOption, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfIndexOption(x)) + } + return res +} + +// CloneSliceOfRefOfPartitionDefinition creates a deep clone of the input. +func CloneSliceOfRefOfPartitionDefinition(n []*PartitionDefinition) []*PartitionDefinition { + res := make([]*PartitionDefinition, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfPartitionDefinition(x)) + } + return res +} + +// CloneSliceOfRefOfRenameTablePair creates a deep clone of the input. +func CloneSliceOfRefOfRenameTablePair(n []*RenameTablePair) []*RenameTablePair { + res := make([]*RenameTablePair, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfRenameTablePair(x)) + } + return res +} + +// CloneSliceOfRefOfUnionSelect creates a deep clone of the input. +func CloneSliceOfRefOfUnionSelect(n []*UnionSelect) []*UnionSelect { + res := make([]*UnionSelect, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUnionSelect(x)) + } + return res +} + +// CloneSliceOfRefOfWhen creates a deep clone of the input. +func CloneSliceOfRefOfWhen(n []*When) []*When { + res := make([]*When, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfWhen(x)) + } + return res +} + +// CloneSliceOfString creates a deep clone of the input. +func CloneSliceOfString(n []string) []string { + res := make([]string, 0, len(n)) + copy(res, n) + return res +} + +// CloneSliceOfVindexParam creates a deep clone of the input. +func CloneSliceOfVindexParam(n []VindexParam) []VindexParam { + res := make([]VindexParam, 0, len(n)) + for _, x := range n { + res = append(res, CloneVindexParam(x)) + } + return res +} + +// CloneStatement creates a deep clone of the input. +func CloneStatement(in Statement) Statement { + if in == nil { + return nil + } + switch in := in.(type) { + case *AlterDatabase: + return CloneRefOfAlterDatabase(in) + case *AlterMigration: + return CloneRefOfAlterMigration(in) + case *AlterTable: + return CloneRefOfAlterTable(in) + case *AlterView: + return CloneRefOfAlterView(in) + case *AlterVschema: + return CloneRefOfAlterVschema(in) + case *Begin: + return CloneRefOfBegin(in) + case *CallProc: + return CloneRefOfCallProc(in) + case *Commit: + return CloneRefOfCommit(in) + case *CreateDatabase: + return CloneRefOfCreateDatabase(in) + case *CreateTable: + return CloneRefOfCreateTable(in) + case *CreateView: + return CloneRefOfCreateView(in) + case *Delete: + return CloneRefOfDelete(in) + case *DropDatabase: + return CloneRefOfDropDatabase(in) + case *DropTable: + return CloneRefOfDropTable(in) + case *DropView: + return CloneRefOfDropView(in) + case *ExplainStmt: + return CloneRefOfExplainStmt(in) + case *ExplainTab: + return CloneRefOfExplainTab(in) + case *Flush: + return CloneRefOfFlush(in) + case *Insert: + return CloneRefOfInsert(in) + case *Load: + return CloneRefOfLoad(in) + case *LockTables: + return CloneRefOfLockTables(in) + case *OtherAdmin: + return CloneRefOfOtherAdmin(in) + case *OtherRead: + return CloneRefOfOtherRead(in) + case *ParenSelect: + return CloneRefOfParenSelect(in) + case *Release: + return CloneRefOfRelease(in) + case *RenameTable: + return CloneRefOfRenameTable(in) + case *RevertMigration: + return CloneRefOfRevertMigration(in) + case *Rollback: + return CloneRefOfRollback(in) + case *SRollback: + return CloneRefOfSRollback(in) + case *Savepoint: + return CloneRefOfSavepoint(in) + case *Select: + return CloneRefOfSelect(in) + case *Set: + return CloneRefOfSet(in) + case *SetTransaction: + return CloneRefOfSetTransaction(in) + case *Show: + return CloneRefOfShow(in) + case *Stream: + return CloneRefOfStream(in) + case *TruncateTable: + return CloneRefOfTruncateTable(in) + case *Union: + return CloneRefOfUnion(in) + case *UnlockTables: + return CloneRefOfUnlockTables(in) + case *Update: + return CloneRefOfUpdate(in) + case *Use: + return CloneRefOfUse(in) + case *VStream: + return CloneRefOfVStream(in) + default: + // this should never happen return nil } - out := *n - out.Source = CloneColumns(n.Source) - out.ReferencedTable = CloneTableName(n.ReferencedTable) - out.ReferencedColumns = CloneColumns(n.ReferencedColumns) - return &out } -// VisitRefOfForeignKeyDefinition will visit all parts of the AST -func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { - if in == nil { - return nil +// CloneTableAndLockTypes creates a deep clone of the input. +func CloneTableAndLockTypes(n TableAndLockTypes) TableAndLockTypes { + res := make(TableAndLockTypes, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableAndLockType(x)) + } + return res +} + +// CloneTableExpr creates a deep clone of the input. +func CloneTableExpr(in TableExpr) TableExpr { + if in == nil { + return nil + } + switch in := in.(type) { + case *AliasedTableExpr: + return CloneRefOfAliasedTableExpr(in) + case *JoinTableExpr: + return CloneRefOfJoinTableExpr(in) + case *ParenTableExpr: + return CloneRefOfParenTableExpr(in) + default: + // this should never happen + return nil + } +} + +// CloneTableExprs creates a deep clone of the input. +func CloneTableExprs(n TableExprs) TableExprs { + res := make(TableExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableExpr(x)) + } + return res +} + +// CloneTableIdent creates a deep clone of the input. +func CloneTableIdent(n TableIdent) TableIdent { + return *CloneRefOfTableIdent(&n) +} + +// CloneTableName creates a deep clone of the input. +func CloneTableName(n TableName) TableName { + return *CloneRefOfTableName(&n) +} + +// CloneTableNames creates a deep clone of the input. +func CloneTableNames(n TableNames) TableNames { + res := make(TableNames, 0, len(n)) + for _, x := range n { + res = append(res, CloneTableName(x)) + } + return res +} + +// CloneTableOptions creates a deep clone of the input. +func CloneTableOptions(n TableOptions) TableOptions { + res := make(TableOptions, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfTableOption(x)) + } + return res +} + +// CloneUpdateExprs creates a deep clone of the input. +func CloneUpdateExprs(n UpdateExprs) UpdateExprs { + res := make(UpdateExprs, 0, len(n)) + for _, x := range n { + res = append(res, CloneRefOfUpdateExpr(x)) + } + return res +} + +// CloneValTuple creates a deep clone of the input. +func CloneValTuple(n ValTuple) ValTuple { + res := make(ValTuple, 0, len(n)) + for _, x := range n { + res = append(res, CloneExpr(x)) + } + return res +} + +// CloneValues creates a deep clone of the input. +func CloneValues(n Values) Values { + res := make(Values, 0, len(n)) + for _, x := range n { + res = append(res, CloneValTuple(x)) + } + return res +} + +// CloneVindexParam creates a deep clone of the input. +func CloneVindexParam(n VindexParam) VindexParam { + return *CloneRefOfVindexParam(&n) +} + +// EqualsAlterOption does deep equals between the two objects. +func EqualsAlterOption(inA, inB AlterOption) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AddColumns: + b, ok := inB.(*AddColumns) + if !ok { + return false + } + return EqualsRefOfAddColumns(a, b) + case *AddConstraintDefinition: + b, ok := inB.(*AddConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfAddConstraintDefinition(a, b) + case *AddIndexDefinition: + b, ok := inB.(*AddIndexDefinition) + if !ok { + return false + } + return EqualsRefOfAddIndexDefinition(a, b) + case AlgorithmValue: + b, ok := inB.(AlgorithmValue) + if !ok { + return false + } + return a == b + case *AlterCharset: + b, ok := inB.(*AlterCharset) + if !ok { + return false + } + return EqualsRefOfAlterCharset(a, b) + case *AlterColumn: + b, ok := inB.(*AlterColumn) + if !ok { + return false + } + return EqualsRefOfAlterColumn(a, b) + case *ChangeColumn: + b, ok := inB.(*ChangeColumn) + if !ok { + return false + } + return EqualsRefOfChangeColumn(a, b) + case *DropColumn: + b, ok := inB.(*DropColumn) + if !ok { + return false + } + return EqualsRefOfDropColumn(a, b) + case *DropKey: + b, ok := inB.(*DropKey) + if !ok { + return false + } + return EqualsRefOfDropKey(a, b) + case *Force: + b, ok := inB.(*Force) + if !ok { + return false + } + return EqualsRefOfForce(a, b) + case *KeyState: + b, ok := inB.(*KeyState) + if !ok { + return false + } + return EqualsRefOfKeyState(a, b) + case *LockOption: + b, ok := inB.(*LockOption) + if !ok { + return false + } + return EqualsRefOfLockOption(a, b) + case *ModifyColumn: + b, ok := inB.(*ModifyColumn) + if !ok { + return false + } + return EqualsRefOfModifyColumn(a, b) + case *OrderByOption: + b, ok := inB.(*OrderByOption) + if !ok { + return false + } + return EqualsRefOfOrderByOption(a, b) + case *RenameIndex: + b, ok := inB.(*RenameIndex) + if !ok { + return false + } + return EqualsRefOfRenameIndex(a, b) + case *RenameTableName: + b, ok := inB.(*RenameTableName) + if !ok { + return false + } + return EqualsRefOfRenameTableName(a, b) + case TableOptions: + b, ok := inB.(TableOptions) + if !ok { + return false + } + return EqualsTableOptions(a, b) + case *TablespaceOperation: + b, ok := inB.(*TablespaceOperation) + if !ok { + return false + } + return EqualsRefOfTablespaceOperation(a, b) + case *Validation: + b, ok := inB.(*Validation) + if !ok { + return false + } + return EqualsRefOfValidation(a, b) + default: + // this should never happen + return false + } +} + +// EqualsCharacteristic does deep equals between the two objects. +func EqualsCharacteristic(inA, inB Characteristic) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case AccessMode: + b, ok := inB.(AccessMode) + if !ok { + return false + } + return a == b + case IsolationLevel: + b, ok := inB.(IsolationLevel) + if !ok { + return false + } + return a == b + default: + // this should never happen + return false + } +} + +// EqualsColIdent does deep equals between the two objects. +func EqualsColIdent(a, b ColIdent) bool { + return a.val == b.val && + a.lowered == b.lowered && + a.at == b.at +} + +// EqualsColTuple does deep equals between the two objects. +func EqualsColTuple(inA, inB ColTuple) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + default: + // this should never happen + return false + } +} + +// EqualsCollateAndCharset does deep equals between the two objects. +func EqualsCollateAndCharset(a, b CollateAndCharset) bool { + return a.IsDefault == b.IsDefault && + a.Value == b.Value && + a.Type == b.Type +} + +// EqualsColumnType does deep equals between the two objects. +func EqualsColumnType(a, b ColumnType) bool { + return a.Type == b.Type && + a.Unsigned == b.Unsigned && + a.Zerofill == b.Zerofill && + a.Charset == b.Charset && + a.Collate == b.Collate && + EqualsRefOfColumnTypeOptions(a.Options, b.Options) && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + EqualsSliceOfString(a.EnumValues, b.EnumValues) +} + +// EqualsColumns does deep equals between the two objects. +func EqualsColumns(a, b Columns) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsComments does deep equals between the two objects. +func EqualsComments(a, b Comments) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// EqualsConstraintInfo does deep equals between the two objects. +func EqualsConstraintInfo(inA, inB ConstraintInfo) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *CheckConstraintDefinition: + b, ok := inB.(*CheckConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfCheckConstraintDefinition(a, b) + case *ForeignKeyDefinition: + b, ok := inB.(*ForeignKeyDefinition) + if !ok { + return false + } + return EqualsRefOfForeignKeyDefinition(a, b) + default: + // this should never happen + return false + } +} + +// EqualsDBDDLStatement does deep equals between the two objects. +func EqualsDBDDLStatement(inA, inB DBDDLStatement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + default: + // this should never happen + return false + } +} + +// EqualsDDLStatement does deep equals between the two objects. +func EqualsDDLStatement(inA, inB DDLStatement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExplain does deep equals between the two objects. +func EqualsExplain(inA, inB Explain) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExpr does deep equals between the two objects. +func EqualsExpr(inA, inB Expr) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AndExpr: + b, ok := inB.(*AndExpr) + if !ok { + return false + } + return EqualsRefOfAndExpr(a, b) + case Argument: + b, ok := inB.(Argument) + if !ok { + return false + } + return a == b + case *BinaryExpr: + b, ok := inB.(*BinaryExpr) + if !ok { + return false + } + return EqualsRefOfBinaryExpr(a, b) + case BoolVal: + b, ok := inB.(BoolVal) + if !ok { + return false + } + return a == b + case *CaseExpr: + b, ok := inB.(*CaseExpr) + if !ok { + return false + } + return EqualsRefOfCaseExpr(a, b) + case *ColName: + b, ok := inB.(*ColName) + if !ok { + return false + } + return EqualsRefOfColName(a, b) + case *CollateExpr: + b, ok := inB.(*CollateExpr) + if !ok { + return false + } + return EqualsRefOfCollateExpr(a, b) + case *ComparisonExpr: + b, ok := inB.(*ComparisonExpr) + if !ok { + return false + } + return EqualsRefOfComparisonExpr(a, b) + case *ConvertExpr: + b, ok := inB.(*ConvertExpr) + if !ok { + return false + } + return EqualsRefOfConvertExpr(a, b) + case *ConvertUsingExpr: + b, ok := inB.(*ConvertUsingExpr) + if !ok { + return false + } + return EqualsRefOfConvertUsingExpr(a, b) + case *CurTimeFuncExpr: + b, ok := inB.(*CurTimeFuncExpr) + if !ok { + return false + } + return EqualsRefOfCurTimeFuncExpr(a, b) + case *Default: + b, ok := inB.(*Default) + if !ok { + return false + } + return EqualsRefOfDefault(a, b) + case *ExistsExpr: + b, ok := inB.(*ExistsExpr) + if !ok { + return false + } + return EqualsRefOfExistsExpr(a, b) + case *FuncExpr: + b, ok := inB.(*FuncExpr) + if !ok { + return false + } + return EqualsRefOfFuncExpr(a, b) + case *GroupConcatExpr: + b, ok := inB.(*GroupConcatExpr) + if !ok { + return false + } + return EqualsRefOfGroupConcatExpr(a, b) + case *IntervalExpr: + b, ok := inB.(*IntervalExpr) + if !ok { + return false + } + return EqualsRefOfIntervalExpr(a, b) + case *IsExpr: + b, ok := inB.(*IsExpr) + if !ok { + return false + } + return EqualsRefOfIsExpr(a, b) + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Literal: + b, ok := inB.(*Literal) + if !ok { + return false + } + return EqualsRefOfLiteral(a, b) + case *MatchExpr: + b, ok := inB.(*MatchExpr) + if !ok { + return false + } + return EqualsRefOfMatchExpr(a, b) + case *NotExpr: + b, ok := inB.(*NotExpr) + if !ok { + return false + } + return EqualsRefOfNotExpr(a, b) + case *NullVal: + b, ok := inB.(*NullVal) + if !ok { + return false + } + return EqualsRefOfNullVal(a, b) + case *OrExpr: + b, ok := inB.(*OrExpr) + if !ok { + return false + } + return EqualsRefOfOrExpr(a, b) + case *RangeCond: + b, ok := inB.(*RangeCond) + if !ok { + return false + } + return EqualsRefOfRangeCond(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case *SubstrExpr: + b, ok := inB.(*SubstrExpr) + if !ok { + return false + } + return EqualsRefOfSubstrExpr(a, b) + case *TimestampFuncExpr: + b, ok := inB.(*TimestampFuncExpr) + if !ok { + return false + } + return EqualsRefOfTimestampFuncExpr(a, b) + case *UnaryExpr: + b, ok := inB.(*UnaryExpr) + if !ok { + return false + } + return EqualsRefOfUnaryExpr(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + case *ValuesFuncExpr: + b, ok := inB.(*ValuesFuncExpr) + if !ok { + return false + } + return EqualsRefOfValuesFuncExpr(a, b) + case *XorExpr: + b, ok := inB.(*XorExpr) + if !ok { + return false + } + return EqualsRefOfXorExpr(a, b) + default: + // this should never happen + return false + } +} + +// EqualsExprs does deep equals between the two objects. +func EqualsExprs(a, b Exprs) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsGroupBy does deep equals between the two objects. +func EqualsGroupBy(a, b GroupBy) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsInsertRows does deep equals between the two objects. +func EqualsInsertRows(inA, inB InsertRows) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case Values: + b, ok := inB.(Values) + if !ok { + return false + } + return EqualsValues(a, b) + default: + // this should never happen + return false + } +} + +// EqualsJoinCondition does deep equals between the two objects. +func EqualsJoinCondition(a, b JoinCondition) bool { + return EqualsExpr(a.On, b.On) && + EqualsColumns(a.Using, b.Using) +} + +// EqualsListArg does deep equals between the two objects. +func EqualsListArg(a, b ListArg) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// EqualsOnDup does deep equals between the two objects. +func EqualsOnDup(a, b OnDup) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfUpdateExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsOrderBy does deep equals between the two objects. +func EqualsOrderBy(a, b OrderBy) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfOrder(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsPartitions does deep equals between the two objects. +func EqualsPartitions(a, b Partitions) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsRefOfAddColumns does deep equals between the two objects. +func EqualsRefOfAddColumns(a, b *AddColumns) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) +} + +// EqualsRefOfAddConstraintDefinition does deep equals between the two objects. +func EqualsRefOfAddConstraintDefinition(a, b *AddConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfConstraintDefinition(a.ConstraintDefinition, b.ConstraintDefinition) +} + +// EqualsRefOfAddIndexDefinition does deep equals between the two objects. +func EqualsRefOfAddIndexDefinition(a, b *AddIndexDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfIndexDefinition(a.IndexDefinition, b.IndexDefinition) +} + +// EqualsRefOfAliasedExpr does deep equals between the two objects. +func EqualsRefOfAliasedExpr(a, b *AliasedExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsColIdent(a.As, b.As) +} + +// EqualsRefOfAliasedTableExpr does deep equals between the two objects. +func EqualsRefOfAliasedTableExpr(a, b *AliasedTableExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSimpleTableExpr(a.Expr, b.Expr) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsTableIdent(a.As, b.As) && + EqualsRefOfIndexHints(a.Hints, b.Hints) +} + +// EqualsRefOfAlterCharset does deep equals between the two objects. +func EqualsRefOfAlterCharset(a, b *AlterCharset) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.CharacterSet == b.CharacterSet && + a.Collate == b.Collate +} + +// EqualsRefOfAlterColumn does deep equals between the two objects. +func EqualsRefOfAlterColumn(a, b *AlterColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DropDefault == b.DropDefault && + EqualsRefOfColName(a.Column, b.Column) && + EqualsExpr(a.DefaultVal, b.DefaultVal) +} + +// EqualsRefOfAlterDatabase does deep equals between the two objects. +func EqualsRefOfAlterDatabase(a, b *AlterDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.UpdateDataDirectory == b.UpdateDataDirectory && + a.FullyParsed == b.FullyParsed && + EqualsSliceOfCollateAndCharset(a.AlterOptions, b.AlterOptions) +} + +// EqualsRefOfAlterMigration does deep equals between the two objects. +func EqualsRefOfAlterMigration(a, b *AlterMigration) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.UUID == b.UUID && + a.Type == b.Type +} + +// EqualsRefOfAlterTable does deep equals between the two objects. +func EqualsRefOfAlterTable(a, b *AlterTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.FullyParsed == b.FullyParsed && + EqualsTableName(a.Table, b.Table) && + EqualsSliceOfAlterOption(a.AlterOptions, b.AlterOptions) && + EqualsRefOfPartitionSpec(a.PartitionSpec, b.PartitionSpec) +} + +// EqualsRefOfAlterView does deep equals between the two objects. +func EqualsRefOfAlterView(a, b *AlterView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Algorithm == b.Algorithm && + a.Definer == b.Definer && + a.Security == b.Security && + a.CheckOption == b.CheckOption && + EqualsTableName(a.ViewName, b.ViewName) && + EqualsColumns(a.Columns, b.Columns) && + EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfAlterVschema does deep equals between the two objects. +func EqualsRefOfAlterVschema(a, b *AlterVschema) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Action == b.Action && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfVindexSpec(a.VindexSpec, b.VindexSpec) && + EqualsSliceOfColIdent(a.VindexCols, b.VindexCols) && + EqualsRefOfAutoIncSpec(a.AutoIncSpec, b.AutoIncSpec) +} + +// EqualsRefOfAndExpr does deep equals between the two objects. +func EqualsRefOfAndExpr(a, b *AndExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsRefOfAutoIncSpec does deep equals between the two objects. +func EqualsRefOfAutoIncSpec(a, b *AutoIncSpec) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Column, b.Column) && + EqualsTableName(a.Sequence, b.Sequence) +} + +// EqualsRefOfBegin does deep equals between the two objects. +func EqualsRefOfBegin(a, b *Begin) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfBinaryExpr does deep equals between the two objects. +func EqualsRefOfBinaryExpr(a, b *BinaryExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsRefOfBool does deep equals between the two objects. +func EqualsRefOfBool(a, b *bool) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// EqualsRefOfCallProc does deep equals between the two objects. +func EqualsRefOfCallProc(a, b *CallProc) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.Name, b.Name) && + EqualsExprs(a.Params, b.Params) +} + +// EqualsRefOfCaseExpr does deep equals between the two objects. +func EqualsRefOfCaseExpr(a, b *CaseExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsSliceOfRefOfWhen(a.Whens, b.Whens) && + EqualsExpr(a.Else, b.Else) +} + +// EqualsRefOfChangeColumn does deep equals between the two objects. +func EqualsRefOfChangeColumn(a, b *ChangeColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.OldColumn, b.OldColumn) && + EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) +} + +// EqualsRefOfCheckConstraintDefinition does deep equals between the two objects. +func EqualsRefOfCheckConstraintDefinition(a, b *CheckConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Enforced == b.Enforced && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfColIdent does deep equals between the two objects. +func EqualsRefOfColIdent(a, b *ColIdent) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.val == b.val && + a.lowered == b.lowered && + a.at == b.at +} + +// EqualsRefOfColName does deep equals between the two objects. +func EqualsRefOfColName(a, b *ColName) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsTableName(a.Qualifier, b.Qualifier) +} + +// EqualsRefOfCollateAndCharset does deep equals between the two objects. +func EqualsRefOfCollateAndCharset(a, b *CollateAndCharset) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IsDefault == b.IsDefault && + a.Value == b.Value && + a.Type == b.Type +} + +// EqualsRefOfCollateExpr does deep equals between the two objects. +func EqualsRefOfCollateExpr(a, b *CollateExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Charset == b.Charset && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfColumnDefinition does deep equals between the two objects. +func EqualsRefOfColumnDefinition(a, b *ColumnDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsColumnType(a.Type, b.Type) +} + +// EqualsRefOfColumnType does deep equals between the two objects. +func EqualsRefOfColumnType(a, b *ColumnType) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + a.Unsigned == b.Unsigned && + a.Zerofill == b.Zerofill && + a.Charset == b.Charset && + a.Collate == b.Collate && + EqualsRefOfColumnTypeOptions(a.Options, b.Options) && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + EqualsSliceOfString(a.EnumValues, b.EnumValues) +} + +// EqualsRefOfColumnTypeOptions does deep equals between the two objects. +func EqualsRefOfColumnTypeOptions(a, b *ColumnTypeOptions) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.NotNull == b.NotNull && + a.Autoincrement == b.Autoincrement && + EqualsExpr(a.Default, b.Default) && + EqualsExpr(a.OnUpdate, b.OnUpdate) && + EqualsRefOfLiteral(a.Comment, b.Comment) && + a.KeyOpt == b.KeyOpt +} + +// EqualsRefOfCommit does deep equals between the two objects. +func EqualsRefOfCommit(a, b *Commit) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfComparisonExpr does deep equals between the two objects. +func EqualsRefOfComparisonExpr(a, b *ComparisonExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) && + EqualsExpr(a.Escape, b.Escape) +} + +// EqualsRefOfConstraintDefinition does deep equals between the two objects. +func EqualsRefOfConstraintDefinition(a, b *ConstraintDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + EqualsConstraintInfo(a.Details, b.Details) +} + +// EqualsRefOfConvertExpr does deep equals between the two objects. +func EqualsRefOfConvertExpr(a, b *ConvertExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + EqualsRefOfConvertType(a.Type, b.Type) +} + +// EqualsRefOfConvertType does deep equals between the two objects. +func EqualsRefOfConvertType(a, b *ConvertType) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + a.Charset == b.Charset && + EqualsRefOfLiteral(a.Length, b.Length) && + EqualsRefOfLiteral(a.Scale, b.Scale) && + a.Operator == b.Operator +} + +// EqualsRefOfConvertUsingExpr does deep equals between the two objects. +func EqualsRefOfConvertUsingExpr(a, b *ConvertUsingExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfCreateDatabase does deep equals between the two objects. +func EqualsRefOfCreateDatabase(a, b *CreateDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.IfNotExists == b.IfNotExists && + a.FullyParsed == b.FullyParsed && + EqualsComments(a.Comments, b.Comments) && + EqualsSliceOfCollateAndCharset(a.CreateOptions, b.CreateOptions) +} + +// EqualsRefOfCreateTable does deep equals between the two objects. +func EqualsRefOfCreateTable(a, b *CreateTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Temp == b.Temp && + a.IfNotExists == b.IfNotExists && + a.FullyParsed == b.FullyParsed && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfTableSpec(a.TableSpec, b.TableSpec) && + EqualsRefOfOptLike(a.OptLike, b.OptLike) +} + +// EqualsRefOfCreateView does deep equals between the two objects. +func EqualsRefOfCreateView(a, b *CreateView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Algorithm == b.Algorithm && + a.Definer == b.Definer && + a.Security == b.Security && + a.CheckOption == b.CheckOption && + a.IsReplace == b.IsReplace && + EqualsTableName(a.ViewName, b.ViewName) && + EqualsColumns(a.Columns, b.Columns) && + EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfCurTimeFuncExpr does deep equals between the two objects. +func EqualsRefOfCurTimeFuncExpr(a, b *CurTimeFuncExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Fsp, b.Fsp) +} + +// EqualsRefOfDefault does deep equals between the two objects. +func EqualsRefOfDefault(a, b *Default) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.ColName == b.ColName +} + +// EqualsRefOfDelete does deep equals between the two objects. +func EqualsRefOfDelete(a, b *Delete) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Ignore == b.Ignore && + EqualsComments(a.Comments, b.Comments) && + EqualsTableNames(a.Targets, b.Targets) && + EqualsTableExprs(a.TableExprs, b.TableExprs) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) +} + +// EqualsRefOfDerivedTable does deep equals between the two objects. +func EqualsRefOfDerivedTable(a, b *DerivedTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfDropColumn does deep equals between the two objects. +func EqualsRefOfDropColumn(a, b *DropColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.Name, b.Name) +} + +// EqualsRefOfDropDatabase does deep equals between the two objects. +func EqualsRefOfDropDatabase(a, b *DropDatabase) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.DBName == b.DBName && + a.IfExists == b.IfExists && + EqualsComments(a.Comments, b.Comments) +} + +// EqualsRefOfDropKey does deep equals between the two objects. +func EqualsRefOfDropKey(a, b *DropKey) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + a.Type == b.Type +} + +// EqualsRefOfDropTable does deep equals between the two objects. +func EqualsRefOfDropTable(a, b *DropTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Temp == b.Temp && + a.IfExists == b.IfExists && + EqualsTableNames(a.FromTables, b.FromTables) +} + +// EqualsRefOfDropView does deep equals between the two objects. +func EqualsRefOfDropView(a, b *DropView) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IfExists == b.IfExists && + EqualsTableNames(a.FromTables, b.FromTables) +} + +// EqualsRefOfExistsExpr does deep equals between the two objects. +func EqualsRefOfExistsExpr(a, b *ExistsExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfSubquery(a.Subquery, b.Subquery) +} + +// EqualsRefOfExplainStmt does deep equals between the two objects. +func EqualsRefOfExplainStmt(a, b *ExplainStmt) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsStatement(a.Statement, b.Statement) +} + +// EqualsRefOfExplainTab does deep equals between the two objects. +func EqualsRefOfExplainTab(a, b *ExplainTab) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Wild == b.Wild && + EqualsTableName(a.Table, b.Table) +} + +// EqualsRefOfFlush does deep equals between the two objects. +func EqualsRefOfFlush(a, b *Flush) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IsLocal == b.IsLocal && + a.WithLock == b.WithLock && + a.ForExport == b.ForExport && + EqualsSliceOfString(a.FlushOptions, b.FlushOptions) && + EqualsTableNames(a.TableNames, b.TableNames) +} + +// EqualsRefOfForce does deep equals between the two objects. +func EqualsRefOfForce(a, b *Force) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfForeignKeyDefinition does deep equals between the two objects. +func EqualsRefOfForeignKeyDefinition(a, b *ForeignKeyDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColumns(a.Source, b.Source) && + EqualsTableName(a.ReferencedTable, b.ReferencedTable) && + EqualsColumns(a.ReferencedColumns, b.ReferencedColumns) && + a.OnDelete == b.OnDelete && + a.OnUpdate == b.OnUpdate +} + +// EqualsRefOfFuncExpr does deep equals between the two objects. +func EqualsRefOfFuncExpr(a, b *FuncExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Distinct == b.Distinct && + EqualsTableIdent(a.Qualifier, b.Qualifier) && + EqualsColIdent(a.Name, b.Name) && + EqualsSelectExprs(a.Exprs, b.Exprs) +} + +// EqualsRefOfGroupConcatExpr does deep equals between the two objects. +func EqualsRefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Distinct == b.Distinct && + a.Separator == b.Separator && + EqualsSelectExprs(a.Exprs, b.Exprs) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) +} + +// EqualsRefOfIndexColumn does deep equals between the two objects. +func EqualsRefOfIndexColumn(a, b *IndexColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Column, b.Column) && + EqualsRefOfLiteral(a.Length, b.Length) && + a.Direction == b.Direction +} + +// EqualsRefOfIndexDefinition does deep equals between the two objects. +func EqualsRefOfIndexDefinition(a, b *IndexDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfIndexInfo(a.Info, b.Info) && + EqualsSliceOfRefOfIndexColumn(a.Columns, b.Columns) && + EqualsSliceOfRefOfIndexOption(a.Options, b.Options) +} + +// EqualsRefOfIndexHints does deep equals between the two objects. +func EqualsRefOfIndexHints(a, b *IndexHints) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsSliceOfColIdent(a.Indexes, b.Indexes) +} + +// EqualsRefOfIndexInfo does deep equals between the two objects. +func EqualsRefOfIndexInfo(a, b *IndexInfo) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + a.Primary == b.Primary && + a.Spatial == b.Spatial && + a.Fulltext == b.Fulltext && + a.Unique == b.Unique && + EqualsColIdent(a.Name, b.Name) && + EqualsColIdent(a.ConstraintName, b.ConstraintName) +} + +// EqualsRefOfIndexOption does deep equals between the two objects. +func EqualsRefOfIndexOption(a, b *IndexOption) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + a.String == b.String && + EqualsRefOfLiteral(a.Value, b.Value) +} + +// EqualsRefOfInsert does deep equals between the two objects. +func EqualsRefOfInsert(a, b *Insert) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Action == b.Action && + EqualsComments(a.Comments, b.Comments) && + a.Ignore == b.Ignore && + EqualsTableName(a.Table, b.Table) && + EqualsPartitions(a.Partitions, b.Partitions) && + EqualsColumns(a.Columns, b.Columns) && + EqualsInsertRows(a.Rows, b.Rows) && + EqualsOnDup(a.OnDup, b.OnDup) +} + +// EqualsRefOfIntervalExpr does deep equals between the two objects. +func EqualsRefOfIntervalExpr(a, b *IntervalExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Unit == b.Unit && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfIsExpr does deep equals between the two objects. +func EqualsRefOfIsExpr(a, b *IsExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfJoinCondition does deep equals between the two objects. +func EqualsRefOfJoinCondition(a, b *JoinCondition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.On, b.On) && + EqualsColumns(a.Using, b.Using) +} + +// EqualsRefOfJoinTableExpr does deep equals between the two objects. +func EqualsRefOfJoinTableExpr(a, b *JoinTableExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableExpr(a.LeftExpr, b.LeftExpr) && + a.Join == b.Join && + EqualsTableExpr(a.RightExpr, b.RightExpr) && + EqualsJoinCondition(a.Condition, b.Condition) +} + +// EqualsRefOfKeyState does deep equals between the two objects. +func EqualsRefOfKeyState(a, b *KeyState) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Enable == b.Enable +} + +// EqualsRefOfLimit does deep equals between the two objects. +func EqualsRefOfLimit(a, b *Limit) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Offset, b.Offset) && + EqualsExpr(a.Rowcount, b.Rowcount) +} + +// EqualsRefOfLiteral does deep equals between the two objects. +func EqualsRefOfLiteral(a, b *Literal) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Val == b.Val && + a.Type == b.Type +} + +// EqualsRefOfLoad does deep equals between the two objects. +func EqualsRefOfLoad(a, b *Load) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfLockOption does deep equals between the two objects. +func EqualsRefOfLockOption(a, b *LockOption) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type +} + +// EqualsRefOfLockTables does deep equals between the two objects. +func EqualsRefOfLockTables(a, b *LockTables) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableAndLockTypes(a.Tables, b.Tables) +} + +// EqualsRefOfMatchExpr does deep equals between the two objects. +func EqualsRefOfMatchExpr(a, b *MatchExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectExprs(a.Columns, b.Columns) && + EqualsExpr(a.Expr, b.Expr) && + a.Option == b.Option +} + +// EqualsRefOfModifyColumn does deep equals between the two objects. +func EqualsRefOfModifyColumn(a, b *ModifyColumn) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && + EqualsRefOfColName(a.First, b.First) && + EqualsRefOfColName(a.After, b.After) +} + +// EqualsRefOfNextval does deep equals between the two objects. +func EqualsRefOfNextval(a, b *Nextval) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfNotExpr does deep equals between the two objects. +func EqualsRefOfNotExpr(a, b *NotExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfNullVal does deep equals between the two objects. +func EqualsRefOfNullVal(a, b *NullVal) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfOptLike does deep equals between the two objects. +func EqualsRefOfOptLike(a, b *OptLike) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.LikeTable, b.LikeTable) +} + +// EqualsRefOfOrExpr does deep equals between the two objects. +func EqualsRefOfOrExpr(a, b *OrExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsRefOfOrder does deep equals between the two objects. +func EqualsRefOfOrder(a, b *Order) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Expr, b.Expr) && + a.Direction == b.Direction +} + +// EqualsRefOfOrderByOption does deep equals between the two objects. +func EqualsRefOfOrderByOption(a, b *OrderByOption) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColumns(a.Cols, b.Cols) +} + +// EqualsRefOfOtherAdmin does deep equals between the two objects. +func EqualsRefOfOtherAdmin(a, b *OtherAdmin) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfOtherRead does deep equals between the two objects. +func EqualsRefOfOtherRead(a, b *OtherRead) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfParenSelect does deep equals between the two objects. +func EqualsRefOfParenSelect(a, b *ParenSelect) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfParenTableExpr does deep equals between the two objects. +func EqualsRefOfParenTableExpr(a, b *ParenTableExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableExprs(a.Exprs, b.Exprs) +} + +// EqualsRefOfPartitionDefinition does deep equals between the two objects. +func EqualsRefOfPartitionDefinition(a, b *PartitionDefinition) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Maxvalue == b.Maxvalue && + EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Limit, b.Limit) +} + +// EqualsRefOfPartitionSpec does deep equals between the two objects. +func EqualsRefOfPartitionSpec(a, b *PartitionSpec) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.IsAll == b.IsAll && + a.WithoutValidation == b.WithoutValidation && + a.Action == b.Action && + EqualsPartitions(a.Names, b.Names) && + EqualsRefOfLiteral(a.Number, b.Number) && + EqualsTableName(a.TableName, b.TableName) && + EqualsSliceOfRefOfPartitionDefinition(a.Definitions, b.Definitions) +} + +// EqualsRefOfRangeCond does deep equals between the two objects. +func EqualsRefOfRangeCond(a, b *RangeCond) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.From, b.From) && + EqualsExpr(a.To, b.To) +} + +// EqualsRefOfRelease does deep equals between the two objects. +func EqualsRefOfRelease(a, b *Release) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) +} + +// EqualsRefOfRenameIndex does deep equals between the two objects. +func EqualsRefOfRenameIndex(a, b *RenameIndex) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.OldName == b.OldName && + a.NewName == b.NewName +} + +// EqualsRefOfRenameTable does deep equals between the two objects. +func EqualsRefOfRenameTable(a, b *RenameTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSliceOfRefOfRenameTablePair(a.TablePairs, b.TablePairs) +} + +// EqualsRefOfRenameTableName does deep equals between the two objects. +func EqualsRefOfRenameTableName(a, b *RenameTableName) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.Table, b.Table) +} + +// EqualsRefOfRenameTablePair does deep equals between the two objects. +func EqualsRefOfRenameTablePair(a, b *RenameTablePair) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.FromTable, b.FromTable) && + EqualsTableName(a.ToTable, b.ToTable) +} + +// EqualsRefOfRevertMigration does deep equals between the two objects. +func EqualsRefOfRevertMigration(a, b *RevertMigration) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.UUID == b.UUID +} + +// EqualsRefOfRollback does deep equals between the two objects. +func EqualsRefOfRollback(a, b *Rollback) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfSRollback does deep equals between the two objects. +func EqualsRefOfSRollback(a, b *SRollback) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) +} + +// EqualsRefOfSavepoint does deep equals between the two objects. +func EqualsRefOfSavepoint(a, b *Savepoint) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) +} + +// EqualsRefOfSelect does deep equals between the two objects. +func EqualsRefOfSelect(a, b *Select) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Distinct == b.Distinct && + a.StraightJoinHint == b.StraightJoinHint && + a.SQLCalcFoundRows == b.SQLCalcFoundRows && + EqualsRefOfBool(a.Cache, b.Cache) && + EqualsComments(a.Comments, b.Comments) && + EqualsSelectExprs(a.SelectExprs, b.SelectExprs) && + EqualsTableExprs(a.From, b.From) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsGroupBy(a.GroupBy, b.GroupBy) && + EqualsRefOfWhere(a.Having, b.Having) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) && + a.Lock == b.Lock && + EqualsRefOfSelectInto(a.Into, b.Into) +} + +// EqualsRefOfSelectInto does deep equals between the two objects. +func EqualsRefOfSelectInto(a, b *SelectInto) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.FileName == b.FileName && + a.Charset == b.Charset && + a.FormatOption == b.FormatOption && + a.ExportOption == b.ExportOption && + a.Manifest == b.Manifest && + a.Overwrite == b.Overwrite && + a.Type == b.Type +} + +// EqualsRefOfSet does deep equals between the two objects. +func EqualsRefOfSet(a, b *Set) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsComments(a.Comments, b.Comments) && + EqualsSetExprs(a.Exprs, b.Exprs) +} + +// EqualsRefOfSetExpr does deep equals between the two objects. +func EqualsRefOfSetExpr(a, b *SetExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Scope == b.Scope && + EqualsColIdent(a.Name, b.Name) && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfSetTransaction does deep equals between the two objects. +func EqualsRefOfSetTransaction(a, b *SetTransaction) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSQLNode(a.SQLNode, b.SQLNode) && + EqualsComments(a.Comments, b.Comments) && + a.Scope == b.Scope && + EqualsSliceOfCharacteristic(a.Characteristics, b.Characteristics) +} + +// EqualsRefOfShow does deep equals between the two objects. +func EqualsRefOfShow(a, b *Show) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsShowInternal(a.Internal, b.Internal) +} + +// EqualsRefOfShowBasic does deep equals between the two objects. +func EqualsRefOfShowBasic(a, b *ShowBasic) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Full == b.Full && + a.DbName == b.DbName && + a.Command == b.Command && + EqualsTableName(a.Tbl, b.Tbl) && + EqualsRefOfShowFilter(a.Filter, b.Filter) +} + +// EqualsRefOfShowCreate does deep equals between the two objects. +func EqualsRefOfShowCreate(a, b *ShowCreate) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Command == b.Command && + EqualsTableName(a.Op, b.Op) +} + +// EqualsRefOfShowFilter does deep equals between the two objects. +func EqualsRefOfShowFilter(a, b *ShowFilter) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Like == b.Like && + EqualsExpr(a.Filter, b.Filter) +} + +// EqualsRefOfShowLegacy does deep equals between the two objects. +func EqualsRefOfShowLegacy(a, b *ShowLegacy) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Extended == b.Extended && + a.Type == b.Type && + EqualsTableName(a.OnTable, b.OnTable) && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfShowTablesOpt(a.ShowTablesOpt, b.ShowTablesOpt) && + a.Scope == b.Scope && + EqualsExpr(a.ShowCollationFilterOpt, b.ShowCollationFilterOpt) +} + +// EqualsRefOfShowTablesOpt does deep equals between the two objects. +func EqualsRefOfShowTablesOpt(a, b *ShowTablesOpt) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Full == b.Full && + a.DbName == b.DbName && + EqualsRefOfShowFilter(a.Filter, b.Filter) +} + +// EqualsRefOfStarExpr does deep equals between the two objects. +func EqualsRefOfStarExpr(a, b *StarExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.TableName, b.TableName) +} + +// EqualsRefOfStream does deep equals between the two objects. +func EqualsRefOfStream(a, b *Stream) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsComments(a.Comments, b.Comments) && + EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && + EqualsTableName(a.Table, b.Table) +} + +// EqualsRefOfSubquery does deep equals between the two objects. +func EqualsRefOfSubquery(a, b *Subquery) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.Select, b.Select) +} + +// EqualsRefOfSubstrExpr does deep equals between the two objects. +func EqualsRefOfSubstrExpr(a, b *SubstrExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.Name, b.Name) && + EqualsRefOfLiteral(a.StrVal, b.StrVal) && + EqualsExpr(a.From, b.From) && + EqualsExpr(a.To, b.To) +} + +// EqualsRefOfTableAndLockType does deep equals between the two objects. +func EqualsRefOfTableAndLockType(a, b *TableAndLockType) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableExpr(a.Table, b.Table) && + a.Lock == b.Lock +} + +// EqualsRefOfTableIdent does deep equals between the two objects. +func EqualsRefOfTableIdent(a, b *TableIdent) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.v == b.v +} + +// EqualsRefOfTableName does deep equals between the two objects. +func EqualsRefOfTableName(a, b *TableName) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableIdent(a.Name, b.Name) && + EqualsTableIdent(a.Qualifier, b.Qualifier) +} + +// EqualsRefOfTableOption does deep equals between the two objects. +func EqualsRefOfTableOption(a, b *TableOption) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + a.String == b.String && + EqualsRefOfLiteral(a.Value, b.Value) && + EqualsTableNames(a.Tables, b.Tables) +} + +// EqualsRefOfTableSpec does deep equals between the two objects. +func EqualsRefOfTableSpec(a, b *TableSpec) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && + EqualsSliceOfRefOfIndexDefinition(a.Indexes, b.Indexes) && + EqualsSliceOfRefOfConstraintDefinition(a.Constraints, b.Constraints) && + EqualsTableOptions(a.Options, b.Options) +} + +// EqualsRefOfTablespaceOperation does deep equals between the two objects. +func EqualsRefOfTablespaceOperation(a, b *TablespaceOperation) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Import == b.Import +} + +// EqualsRefOfTimestampFuncExpr does deep equals between the two objects. +func EqualsRefOfTimestampFuncExpr(a, b *TimestampFuncExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Name == b.Name && + a.Unit == b.Unit && + EqualsExpr(a.Expr1, b.Expr1) && + EqualsExpr(a.Expr2, b.Expr2) +} + +// EqualsRefOfTruncateTable does deep equals between the two objects. +func EqualsRefOfTruncateTable(a, b *TruncateTable) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableName(a.Table, b.Table) +} + +// EqualsRefOfUnaryExpr does deep equals between the two objects. +func EqualsRefOfUnaryExpr(a, b *UnaryExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Operator == b.Operator && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfUnion does deep equals between the two objects. +func EqualsRefOfUnion(a, b *Union) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsSelectStatement(a.FirstStatement, b.FirstStatement) && + EqualsSliceOfRefOfUnionSelect(a.UnionSelects, b.UnionSelects) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) && + a.Lock == b.Lock +} + +// EqualsRefOfUnionSelect does deep equals between the two objects. +func EqualsRefOfUnionSelect(a, b *UnionSelect) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Distinct == b.Distinct && + EqualsSelectStatement(a.Statement, b.Statement) +} + +// EqualsRefOfUnlockTables does deep equals between the two objects. +func EqualsRefOfUnlockTables(a, b *UnlockTables) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return true +} + +// EqualsRefOfUpdate does deep equals between the two objects. +func EqualsRefOfUpdate(a, b *Update) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsComments(a.Comments, b.Comments) && + a.Ignore == b.Ignore && + EqualsTableExprs(a.TableExprs, b.TableExprs) && + EqualsUpdateExprs(a.Exprs, b.Exprs) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsOrderBy(a.OrderBy, b.OrderBy) && + EqualsRefOfLimit(a.Limit, b.Limit) +} + +// EqualsRefOfUpdateExpr does deep equals between the two objects. +func EqualsRefOfUpdateExpr(a, b *UpdateExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.Name, b.Name) && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfUse does deep equals between the two objects. +func EqualsRefOfUse(a, b *Use) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsTableIdent(a.DBName, b.DBName) +} + +// EqualsRefOfVStream does deep equals between the two objects. +func EqualsRefOfVStream(a, b *VStream) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsComments(a.Comments, b.Comments) && + EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && + EqualsTableName(a.Table, b.Table) && + EqualsRefOfWhere(a.Where, b.Where) && + EqualsRefOfLimit(a.Limit, b.Limit) +} + +// EqualsRefOfValidation does deep equals between the two objects. +func EqualsRefOfValidation(a, b *Validation) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.With == b.With +} + +// EqualsRefOfValuesFuncExpr does deep equals between the two objects. +func EqualsRefOfValuesFuncExpr(a, b *ValuesFuncExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsRefOfColName(a.Name, b.Name) +} + +// EqualsRefOfVindexParam does deep equals between the two objects. +func EqualsRefOfVindexParam(a, b *VindexParam) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Val == b.Val && + EqualsColIdent(a.Key, b.Key) +} + +// EqualsRefOfVindexSpec does deep equals between the two objects. +func EqualsRefOfVindexSpec(a, b *VindexSpec) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsColIdent(a.Name, b.Name) && + EqualsColIdent(a.Type, b.Type) && + EqualsSliceOfVindexParam(a.Params, b.Params) +} + +// EqualsRefOfWhen does deep equals between the two objects. +func EqualsRefOfWhen(a, b *When) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Cond, b.Cond) && + EqualsExpr(a.Val, b.Val) +} + +// EqualsRefOfWhere does deep equals between the two objects. +func EqualsRefOfWhere(a, b *Where) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return a.Type == b.Type && + EqualsExpr(a.Expr, b.Expr) +} + +// EqualsRefOfXorExpr does deep equals between the two objects. +func EqualsRefOfXorExpr(a, b *XorExpr) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return EqualsExpr(a.Left, b.Left) && + EqualsExpr(a.Right, b.Right) +} + +// EqualsSQLNode does deep equals between the two objects. +func EqualsSQLNode(inA, inB SQLNode) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case AccessMode: + b, ok := inB.(AccessMode) + if !ok { + return false + } + return a == b + case *AddColumns: + b, ok := inB.(*AddColumns) + if !ok { + return false + } + return EqualsRefOfAddColumns(a, b) + case *AddConstraintDefinition: + b, ok := inB.(*AddConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfAddConstraintDefinition(a, b) + case *AddIndexDefinition: + b, ok := inB.(*AddIndexDefinition) + if !ok { + return false + } + return EqualsRefOfAddIndexDefinition(a, b) + case AlgorithmValue: + b, ok := inB.(AlgorithmValue) + if !ok { + return false + } + return a == b + case *AliasedExpr: + b, ok := inB.(*AliasedExpr) + if !ok { + return false + } + return EqualsRefOfAliasedExpr(a, b) + case *AliasedTableExpr: + b, ok := inB.(*AliasedTableExpr) + if !ok { + return false + } + return EqualsRefOfAliasedTableExpr(a, b) + case *AlterCharset: + b, ok := inB.(*AlterCharset) + if !ok { + return false + } + return EqualsRefOfAlterCharset(a, b) + case *AlterColumn: + b, ok := inB.(*AlterColumn) + if !ok { + return false + } + return EqualsRefOfAlterColumn(a, b) + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *AlterMigration: + b, ok := inB.(*AlterMigration) + if !ok { + return false + } + return EqualsRefOfAlterMigration(a, b) + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *AlterVschema: + b, ok := inB.(*AlterVschema) + if !ok { + return false + } + return EqualsRefOfAlterVschema(a, b) + case *AndExpr: + b, ok := inB.(*AndExpr) + if !ok { + return false + } + return EqualsRefOfAndExpr(a, b) + case Argument: + b, ok := inB.(Argument) + if !ok { + return false + } + return a == b + case *AutoIncSpec: + b, ok := inB.(*AutoIncSpec) + if !ok { + return false + } + return EqualsRefOfAutoIncSpec(a, b) + case *Begin: + b, ok := inB.(*Begin) + if !ok { + return false + } + return EqualsRefOfBegin(a, b) + case *BinaryExpr: + b, ok := inB.(*BinaryExpr) + if !ok { + return false + } + return EqualsRefOfBinaryExpr(a, b) + case BoolVal: + b, ok := inB.(BoolVal) + if !ok { + return false + } + return a == b + case *CallProc: + b, ok := inB.(*CallProc) + if !ok { + return false + } + return EqualsRefOfCallProc(a, b) + case *CaseExpr: + b, ok := inB.(*CaseExpr) + if !ok { + return false + } + return EqualsRefOfCaseExpr(a, b) + case *ChangeColumn: + b, ok := inB.(*ChangeColumn) + if !ok { + return false + } + return EqualsRefOfChangeColumn(a, b) + case *CheckConstraintDefinition: + b, ok := inB.(*CheckConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfCheckConstraintDefinition(a, b) + case ColIdent: + b, ok := inB.(ColIdent) + if !ok { + return false + } + return EqualsColIdent(a, b) + case *ColName: + b, ok := inB.(*ColName) + if !ok { + return false + } + return EqualsRefOfColName(a, b) + case *CollateExpr: + b, ok := inB.(*CollateExpr) + if !ok { + return false + } + return EqualsRefOfCollateExpr(a, b) + case *ColumnDefinition: + b, ok := inB.(*ColumnDefinition) + if !ok { + return false + } + return EqualsRefOfColumnDefinition(a, b) + case *ColumnType: + b, ok := inB.(*ColumnType) + if !ok { + return false + } + return EqualsRefOfColumnType(a, b) + case Columns: + b, ok := inB.(Columns) + if !ok { + return false + } + return EqualsColumns(a, b) + case Comments: + b, ok := inB.(Comments) + if !ok { + return false + } + return EqualsComments(a, b) + case *Commit: + b, ok := inB.(*Commit) + if !ok { + return false + } + return EqualsRefOfCommit(a, b) + case *ComparisonExpr: + b, ok := inB.(*ComparisonExpr) + if !ok { + return false + } + return EqualsRefOfComparisonExpr(a, b) + case *ConstraintDefinition: + b, ok := inB.(*ConstraintDefinition) + if !ok { + return false + } + return EqualsRefOfConstraintDefinition(a, b) + case *ConvertExpr: + b, ok := inB.(*ConvertExpr) + if !ok { + return false + } + return EqualsRefOfConvertExpr(a, b) + case *ConvertType: + b, ok := inB.(*ConvertType) + if !ok { + return false + } + return EqualsRefOfConvertType(a, b) + case *ConvertUsingExpr: + b, ok := inB.(*ConvertUsingExpr) + if !ok { + return false + } + return EqualsRefOfConvertUsingExpr(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *CurTimeFuncExpr: + b, ok := inB.(*CurTimeFuncExpr) + if !ok { + return false + } + return EqualsRefOfCurTimeFuncExpr(a, b) + case *Default: + b, ok := inB.(*Default) + if !ok { + return false + } + return EqualsRefOfDefault(a, b) + case *Delete: + b, ok := inB.(*Delete) + if !ok { + return false + } + return EqualsRefOfDelete(a, b) + case *DerivedTable: + b, ok := inB.(*DerivedTable) + if !ok { + return false + } + return EqualsRefOfDerivedTable(a, b) + case *DropColumn: + b, ok := inB.(*DropColumn) + if !ok { + return false + } + return EqualsRefOfDropColumn(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + case *DropKey: + b, ok := inB.(*DropKey) + if !ok { + return false + } + return EqualsRefOfDropKey(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *ExistsExpr: + b, ok := inB.(*ExistsExpr) + if !ok { + return false + } + return EqualsRefOfExistsExpr(a, b) + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + case Exprs: + b, ok := inB.(Exprs) + if !ok { + return false + } + return EqualsExprs(a, b) + case *Flush: + b, ok := inB.(*Flush) + if !ok { + return false + } + return EqualsRefOfFlush(a, b) + case *Force: + b, ok := inB.(*Force) + if !ok { + return false + } + return EqualsRefOfForce(a, b) + case *ForeignKeyDefinition: + b, ok := inB.(*ForeignKeyDefinition) + if !ok { + return false + } + return EqualsRefOfForeignKeyDefinition(a, b) + case *FuncExpr: + b, ok := inB.(*FuncExpr) + if !ok { + return false + } + return EqualsRefOfFuncExpr(a, b) + case GroupBy: + b, ok := inB.(GroupBy) + if !ok { + return false + } + return EqualsGroupBy(a, b) + case *GroupConcatExpr: + b, ok := inB.(*GroupConcatExpr) + if !ok { + return false + } + return EqualsRefOfGroupConcatExpr(a, b) + case *IndexDefinition: + b, ok := inB.(*IndexDefinition) + if !ok { + return false + } + return EqualsRefOfIndexDefinition(a, b) + case *IndexHints: + b, ok := inB.(*IndexHints) + if !ok { + return false + } + return EqualsRefOfIndexHints(a, b) + case *IndexInfo: + b, ok := inB.(*IndexInfo) + if !ok { + return false + } + return EqualsRefOfIndexInfo(a, b) + case *Insert: + b, ok := inB.(*Insert) + if !ok { + return false + } + return EqualsRefOfInsert(a, b) + case *IntervalExpr: + b, ok := inB.(*IntervalExpr) + if !ok { + return false + } + return EqualsRefOfIntervalExpr(a, b) + case *IsExpr: + b, ok := inB.(*IsExpr) + if !ok { + return false + } + return EqualsRefOfIsExpr(a, b) + case IsolationLevel: + b, ok := inB.(IsolationLevel) + if !ok { + return false + } + return a == b + case JoinCondition: + b, ok := inB.(JoinCondition) + if !ok { + return false + } + return EqualsJoinCondition(a, b) + case *JoinTableExpr: + b, ok := inB.(*JoinTableExpr) + if !ok { + return false + } + return EqualsRefOfJoinTableExpr(a, b) + case *KeyState: + b, ok := inB.(*KeyState) + if !ok { + return false + } + return EqualsRefOfKeyState(a, b) + case *Limit: + b, ok := inB.(*Limit) + if !ok { + return false + } + return EqualsRefOfLimit(a, b) + case ListArg: + b, ok := inB.(ListArg) + if !ok { + return false + } + return EqualsListArg(a, b) + case *Literal: + b, ok := inB.(*Literal) + if !ok { + return false + } + return EqualsRefOfLiteral(a, b) + case *Load: + b, ok := inB.(*Load) + if !ok { + return false + } + return EqualsRefOfLoad(a, b) + case *LockOption: + b, ok := inB.(*LockOption) + if !ok { + return false + } + return EqualsRefOfLockOption(a, b) + case *LockTables: + b, ok := inB.(*LockTables) + if !ok { + return false + } + return EqualsRefOfLockTables(a, b) + case *MatchExpr: + b, ok := inB.(*MatchExpr) + if !ok { + return false + } + return EqualsRefOfMatchExpr(a, b) + case *ModifyColumn: + b, ok := inB.(*ModifyColumn) + if !ok { + return false + } + return EqualsRefOfModifyColumn(a, b) + case *Nextval: + b, ok := inB.(*Nextval) + if !ok { + return false + } + return EqualsRefOfNextval(a, b) + case *NotExpr: + b, ok := inB.(*NotExpr) + if !ok { + return false + } + return EqualsRefOfNotExpr(a, b) + case *NullVal: + b, ok := inB.(*NullVal) + if !ok { + return false + } + return EqualsRefOfNullVal(a, b) + case OnDup: + b, ok := inB.(OnDup) + if !ok { + return false + } + return EqualsOnDup(a, b) + case *OptLike: + b, ok := inB.(*OptLike) + if !ok { + return false + } + return EqualsRefOfOptLike(a, b) + case *OrExpr: + b, ok := inB.(*OrExpr) + if !ok { + return false + } + return EqualsRefOfOrExpr(a, b) + case *Order: + b, ok := inB.(*Order) + if !ok { + return false + } + return EqualsRefOfOrder(a, b) + case OrderBy: + b, ok := inB.(OrderBy) + if !ok { + return false + } + return EqualsOrderBy(a, b) + case *OrderByOption: + b, ok := inB.(*OrderByOption) + if !ok { + return false + } + return EqualsRefOfOrderByOption(a, b) + case *OtherAdmin: + b, ok := inB.(*OtherAdmin) + if !ok { + return false + } + return EqualsRefOfOtherAdmin(a, b) + case *OtherRead: + b, ok := inB.(*OtherRead) + if !ok { + return false + } + return EqualsRefOfOtherRead(a, b) + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *ParenTableExpr: + b, ok := inB.(*ParenTableExpr) + if !ok { + return false + } + return EqualsRefOfParenTableExpr(a, b) + case *PartitionDefinition: + b, ok := inB.(*PartitionDefinition) + if !ok { + return false + } + return EqualsRefOfPartitionDefinition(a, b) + case *PartitionSpec: + b, ok := inB.(*PartitionSpec) + if !ok { + return false + } + return EqualsRefOfPartitionSpec(a, b) + case Partitions: + b, ok := inB.(Partitions) + if !ok { + return false + } + return EqualsPartitions(a, b) + case *RangeCond: + b, ok := inB.(*RangeCond) + if !ok { + return false + } + return EqualsRefOfRangeCond(a, b) + case ReferenceAction: + b, ok := inB.(ReferenceAction) + if !ok { + return false + } + return a == b + case *Release: + b, ok := inB.(*Release) + if !ok { + return false + } + return EqualsRefOfRelease(a, b) + case *RenameIndex: + b, ok := inB.(*RenameIndex) + if !ok { + return false + } + return EqualsRefOfRenameIndex(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *RenameTableName: + b, ok := inB.(*RenameTableName) + if !ok { + return false + } + return EqualsRefOfRenameTableName(a, b) + case *RevertMigration: + b, ok := inB.(*RevertMigration) + if !ok { + return false + } + return EqualsRefOfRevertMigration(a, b) + case *Rollback: + b, ok := inB.(*Rollback) + if !ok { + return false + } + return EqualsRefOfRollback(a, b) + case *SRollback: + b, ok := inB.(*SRollback) + if !ok { + return false + } + return EqualsRefOfSRollback(a, b) + case *Savepoint: + b, ok := inB.(*Savepoint) + if !ok { + return false + } + return EqualsRefOfSavepoint(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case SelectExprs: + b, ok := inB.(SelectExprs) + if !ok { + return false + } + return EqualsSelectExprs(a, b) + case *SelectInto: + b, ok := inB.(*SelectInto) + if !ok { + return false + } + return EqualsRefOfSelectInto(a, b) + case *Set: + b, ok := inB.(*Set) + if !ok { + return false + } + return EqualsRefOfSet(a, b) + case *SetExpr: + b, ok := inB.(*SetExpr) + if !ok { + return false + } + return EqualsRefOfSetExpr(a, b) + case SetExprs: + b, ok := inB.(SetExprs) + if !ok { + return false + } + return EqualsSetExprs(a, b) + case *SetTransaction: + b, ok := inB.(*SetTransaction) + if !ok { + return false + } + return EqualsRefOfSetTransaction(a, b) + case *Show: + b, ok := inB.(*Show) + if !ok { + return false + } + return EqualsRefOfShow(a, b) + case *ShowBasic: + b, ok := inB.(*ShowBasic) + if !ok { + return false + } + return EqualsRefOfShowBasic(a, b) + case *ShowCreate: + b, ok := inB.(*ShowCreate) + if !ok { + return false + } + return EqualsRefOfShowCreate(a, b) + case *ShowFilter: + b, ok := inB.(*ShowFilter) + if !ok { + return false + } + return EqualsRefOfShowFilter(a, b) + case *ShowLegacy: + b, ok := inB.(*ShowLegacy) + if !ok { + return false + } + return EqualsRefOfShowLegacy(a, b) + case *StarExpr: + b, ok := inB.(*StarExpr) + if !ok { + return false + } + return EqualsRefOfStarExpr(a, b) + case *Stream: + b, ok := inB.(*Stream) + if !ok { + return false + } + return EqualsRefOfStream(a, b) + case *Subquery: + b, ok := inB.(*Subquery) + if !ok { + return false + } + return EqualsRefOfSubquery(a, b) + case *SubstrExpr: + b, ok := inB.(*SubstrExpr) + if !ok { + return false + } + return EqualsRefOfSubstrExpr(a, b) + case TableExprs: + b, ok := inB.(TableExprs) + if !ok { + return false + } + return EqualsTableExprs(a, b) + case TableIdent: + b, ok := inB.(TableIdent) + if !ok { + return false + } + return EqualsTableIdent(a, b) + case TableName: + b, ok := inB.(TableName) + if !ok { + return false + } + return EqualsTableName(a, b) + case TableNames: + b, ok := inB.(TableNames) + if !ok { + return false + } + return EqualsTableNames(a, b) + case TableOptions: + b, ok := inB.(TableOptions) + if !ok { + return false + } + return EqualsTableOptions(a, b) + case *TableSpec: + b, ok := inB.(*TableSpec) + if !ok { + return false + } + return EqualsRefOfTableSpec(a, b) + case *TablespaceOperation: + b, ok := inB.(*TablespaceOperation) + if !ok { + return false + } + return EqualsRefOfTablespaceOperation(a, b) + case *TimestampFuncExpr: + b, ok := inB.(*TimestampFuncExpr) + if !ok { + return false + } + return EqualsRefOfTimestampFuncExpr(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + case *UnaryExpr: + b, ok := inB.(*UnaryExpr) + if !ok { + return false + } + return EqualsRefOfUnaryExpr(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case *UnionSelect: + b, ok := inB.(*UnionSelect) + if !ok { + return false + } + return EqualsRefOfUnionSelect(a, b) + case *UnlockTables: + b, ok := inB.(*UnlockTables) + if !ok { + return false + } + return EqualsRefOfUnlockTables(a, b) + case *Update: + b, ok := inB.(*Update) + if !ok { + return false + } + return EqualsRefOfUpdate(a, b) + case *UpdateExpr: + b, ok := inB.(*UpdateExpr) + if !ok { + return false + } + return EqualsRefOfUpdateExpr(a, b) + case UpdateExprs: + b, ok := inB.(UpdateExprs) + if !ok { + return false + } + return EqualsUpdateExprs(a, b) + case *Use: + b, ok := inB.(*Use) + if !ok { + return false + } + return EqualsRefOfUse(a, b) + case *VStream: + b, ok := inB.(*VStream) + if !ok { + return false + } + return EqualsRefOfVStream(a, b) + case ValTuple: + b, ok := inB.(ValTuple) + if !ok { + return false + } + return EqualsValTuple(a, b) + case *Validation: + b, ok := inB.(*Validation) + if !ok { + return false + } + return EqualsRefOfValidation(a, b) + case Values: + b, ok := inB.(Values) + if !ok { + return false + } + return EqualsValues(a, b) + case *ValuesFuncExpr: + b, ok := inB.(*ValuesFuncExpr) + if !ok { + return false + } + return EqualsRefOfValuesFuncExpr(a, b) + case VindexParam: + b, ok := inB.(VindexParam) + if !ok { + return false + } + return EqualsVindexParam(a, b) + case *VindexSpec: + b, ok := inB.(*VindexSpec) + if !ok { + return false + } + return EqualsRefOfVindexSpec(a, b) + case *When: + b, ok := inB.(*When) + if !ok { + return false + } + return EqualsRefOfWhen(a, b) + case *Where: + b, ok := inB.(*Where) + if !ok { + return false + } + return EqualsRefOfWhere(a, b) + case *XorExpr: + b, ok := inB.(*XorExpr) + if !ok { + return false + } + return EqualsRefOfXorExpr(a, b) + default: + // this should never happen + return false + } +} + +// EqualsSelectExpr does deep equals between the two objects. +func EqualsSelectExpr(inA, inB SelectExpr) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AliasedExpr: + b, ok := inB.(*AliasedExpr) + if !ok { + return false + } + return EqualsRefOfAliasedExpr(a, b) + case *Nextval: + b, ok := inB.(*Nextval) + if !ok { + return false + } + return EqualsRefOfNextval(a, b) + case *StarExpr: + b, ok := inB.(*StarExpr) + if !ok { + return false + } + return EqualsRefOfStarExpr(a, b) + default: + // this should never happen + return false + } +} + +// EqualsSelectExprs does deep equals between the two objects. +func EqualsSelectExprs(a, b SelectExprs) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsSelectExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSelectStatement does deep equals between the two objects. +func EqualsSelectStatement(inA, inB SelectStatement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + default: + // this should never happen + return false + } +} + +// EqualsSetExprs does deep equals between the two objects. +func EqualsSetExprs(a, b SetExprs) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfSetExpr(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsShowInternal does deep equals between the two objects. +func EqualsShowInternal(inA, inB ShowInternal) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *ShowBasic: + b, ok := inB.(*ShowBasic) + if !ok { + return false + } + return EqualsRefOfShowBasic(a, b) + case *ShowCreate: + b, ok := inB.(*ShowCreate) + if !ok { + return false + } + return EqualsRefOfShowCreate(a, b) + case *ShowLegacy: + b, ok := inB.(*ShowLegacy) + if !ok { + return false + } + return EqualsRefOfShowLegacy(a, b) + default: + // this should never happen + return false + } +} + +// EqualsSimpleTableExpr does deep equals between the two objects. +func EqualsSimpleTableExpr(inA, inB SimpleTableExpr) bool { + if inA == nil && inB == nil { + return true } - if cont, err := f(in); err != nil || !cont { - return err + if inA == nil || inB == nil { + return false } - if err := VisitColumns(in.Source, f); err != nil { - return err + switch a := inA.(type) { + case *DerivedTable: + b, ok := inB.(*DerivedTable) + if !ok { + return false + } + return EqualsRefOfDerivedTable(a, b) + case TableName: + b, ok := inB.(TableName) + if !ok { + return false + } + return EqualsTableName(a, b) + default: + // this should never happen + return false } - if err := VisitTableName(in.ReferencedTable, f); err != nil { - return err +} + +// EqualsSliceOfAlterOption does deep equals between the two objects. +func EqualsSliceOfAlterOption(a, b []AlterOption) bool { + if len(a) != len(b) { + return false } - if err := VisitColumns(in.ReferencedColumns, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsAlterOption(a[i], b[i]) { + return false + } } - if err := VisitReferenceAction(in.OnDelete, f); err != nil { - return err + return true +} + +// EqualsSliceOfCharacteristic does deep equals between the two objects. +func EqualsSliceOfCharacteristic(a, b []Characteristic) bool { + if len(a) != len(b) { + return false } - if err := VisitReferenceAction(in.OnUpdate, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsCharacteristic(a[i], b[i]) { + return false + } } - return nil + return true } -// EqualsRefOfFuncExpr does deep equals between the two objects. -func EqualsRefOfFuncExpr(a, b *FuncExpr) bool { - if a == b { - return true +// EqualsSliceOfColIdent does deep equals between the two objects. +func EqualsSliceOfColIdent(a, b []ColIdent) bool { + if len(a) != len(b) { + return false } - if a == nil || b == nil { + for i := 0; i < len(a); i++ { + if !EqualsColIdent(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfCollateAndCharset does deep equals between the two objects. +func EqualsSliceOfCollateAndCharset(a, b []CollateAndCharset) bool { + if len(a) != len(b) { return false } - return a.Distinct == b.Distinct && - EqualsTableIdent(a.Qualifier, b.Qualifier) && - EqualsColIdent(a.Name, b.Name) && - EqualsSelectExprs(a.Exprs, b.Exprs) + for i := 0; i < len(a); i++ { + if !EqualsCollateAndCharset(a[i], b[i]) { + return false + } + } + return true } -// CloneRefOfFuncExpr creates a deep clone of the input. -func CloneRefOfFuncExpr(n *FuncExpr) *FuncExpr { - if n == nil { - return nil +// EqualsSliceOfRefOfColumnDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfColumnDefinition(a, b []*ColumnDefinition) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfColumnDefinition(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfConstraintDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfConstraintDefinition(a, b []*ConstraintDefinition) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfConstraintDefinition(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfIndexColumn does deep equals between the two objects. +func EqualsSliceOfRefOfIndexColumn(a, b []*IndexColumn) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexColumn(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfIndexDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfIndexDefinition(a, b []*IndexDefinition) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexDefinition(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfIndexOption does deep equals between the two objects. +func EqualsSliceOfRefOfIndexOption(a, b []*IndexOption) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfIndexOption(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfPartitionDefinition does deep equals between the two objects. +func EqualsSliceOfRefOfPartitionDefinition(a, b []*PartitionDefinition) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfPartitionDefinition(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfRenameTablePair does deep equals between the two objects. +func EqualsSliceOfRefOfRenameTablePair(a, b []*RenameTablePair) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfRenameTablePair(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfUnionSelect does deep equals between the two objects. +func EqualsSliceOfRefOfUnionSelect(a, b []*UnionSelect) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfUnionSelect(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfRefOfWhen does deep equals between the two objects. +func EqualsSliceOfRefOfWhen(a, b []*When) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsRefOfWhen(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsSliceOfString does deep equals between the two objects. +func EqualsSliceOfString(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// EqualsSliceOfVindexParam does deep equals between the two objects. +func EqualsSliceOfVindexParam(a, b []VindexParam) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !EqualsVindexParam(a[i], b[i]) { + return false + } + } + return true +} + +// EqualsStatement does deep equals between the two objects. +func EqualsStatement(inA, inB Statement) bool { + if inA == nil && inB == nil { + return true + } + if inA == nil || inB == nil { + return false + } + switch a := inA.(type) { + case *AlterDatabase: + b, ok := inB.(*AlterDatabase) + if !ok { + return false + } + return EqualsRefOfAlterDatabase(a, b) + case *AlterMigration: + b, ok := inB.(*AlterMigration) + if !ok { + return false + } + return EqualsRefOfAlterMigration(a, b) + case *AlterTable: + b, ok := inB.(*AlterTable) + if !ok { + return false + } + return EqualsRefOfAlterTable(a, b) + case *AlterView: + b, ok := inB.(*AlterView) + if !ok { + return false + } + return EqualsRefOfAlterView(a, b) + case *AlterVschema: + b, ok := inB.(*AlterVschema) + if !ok { + return false + } + return EqualsRefOfAlterVschema(a, b) + case *Begin: + b, ok := inB.(*Begin) + if !ok { + return false + } + return EqualsRefOfBegin(a, b) + case *CallProc: + b, ok := inB.(*CallProc) + if !ok { + return false + } + return EqualsRefOfCallProc(a, b) + case *Commit: + b, ok := inB.(*Commit) + if !ok { + return false + } + return EqualsRefOfCommit(a, b) + case *CreateDatabase: + b, ok := inB.(*CreateDatabase) + if !ok { + return false + } + return EqualsRefOfCreateDatabase(a, b) + case *CreateTable: + b, ok := inB.(*CreateTable) + if !ok { + return false + } + return EqualsRefOfCreateTable(a, b) + case *CreateView: + b, ok := inB.(*CreateView) + if !ok { + return false + } + return EqualsRefOfCreateView(a, b) + case *Delete: + b, ok := inB.(*Delete) + if !ok { + return false + } + return EqualsRefOfDelete(a, b) + case *DropDatabase: + b, ok := inB.(*DropDatabase) + if !ok { + return false + } + return EqualsRefOfDropDatabase(a, b) + case *DropTable: + b, ok := inB.(*DropTable) + if !ok { + return false + } + return EqualsRefOfDropTable(a, b) + case *DropView: + b, ok := inB.(*DropView) + if !ok { + return false + } + return EqualsRefOfDropView(a, b) + case *ExplainStmt: + b, ok := inB.(*ExplainStmt) + if !ok { + return false + } + return EqualsRefOfExplainStmt(a, b) + case *ExplainTab: + b, ok := inB.(*ExplainTab) + if !ok { + return false + } + return EqualsRefOfExplainTab(a, b) + case *Flush: + b, ok := inB.(*Flush) + if !ok { + return false + } + return EqualsRefOfFlush(a, b) + case *Insert: + b, ok := inB.(*Insert) + if !ok { + return false + } + return EqualsRefOfInsert(a, b) + case *Load: + b, ok := inB.(*Load) + if !ok { + return false + } + return EqualsRefOfLoad(a, b) + case *LockTables: + b, ok := inB.(*LockTables) + if !ok { + return false + } + return EqualsRefOfLockTables(a, b) + case *OtherAdmin: + b, ok := inB.(*OtherAdmin) + if !ok { + return false + } + return EqualsRefOfOtherAdmin(a, b) + case *OtherRead: + b, ok := inB.(*OtherRead) + if !ok { + return false + } + return EqualsRefOfOtherRead(a, b) + case *ParenSelect: + b, ok := inB.(*ParenSelect) + if !ok { + return false + } + return EqualsRefOfParenSelect(a, b) + case *Release: + b, ok := inB.(*Release) + if !ok { + return false + } + return EqualsRefOfRelease(a, b) + case *RenameTable: + b, ok := inB.(*RenameTable) + if !ok { + return false + } + return EqualsRefOfRenameTable(a, b) + case *RevertMigration: + b, ok := inB.(*RevertMigration) + if !ok { + return false + } + return EqualsRefOfRevertMigration(a, b) + case *Rollback: + b, ok := inB.(*Rollback) + if !ok { + return false + } + return EqualsRefOfRollback(a, b) + case *SRollback: + b, ok := inB.(*SRollback) + if !ok { + return false + } + return EqualsRefOfSRollback(a, b) + case *Savepoint: + b, ok := inB.(*Savepoint) + if !ok { + return false + } + return EqualsRefOfSavepoint(a, b) + case *Select: + b, ok := inB.(*Select) + if !ok { + return false + } + return EqualsRefOfSelect(a, b) + case *Set: + b, ok := inB.(*Set) + if !ok { + return false + } + return EqualsRefOfSet(a, b) + case *SetTransaction: + b, ok := inB.(*SetTransaction) + if !ok { + return false + } + return EqualsRefOfSetTransaction(a, b) + case *Show: + b, ok := inB.(*Show) + if !ok { + return false + } + return EqualsRefOfShow(a, b) + case *Stream: + b, ok := inB.(*Stream) + if !ok { + return false + } + return EqualsRefOfStream(a, b) + case *TruncateTable: + b, ok := inB.(*TruncateTable) + if !ok { + return false + } + return EqualsRefOfTruncateTable(a, b) + case *Union: + b, ok := inB.(*Union) + if !ok { + return false + } + return EqualsRefOfUnion(a, b) + case *UnlockTables: + b, ok := inB.(*UnlockTables) + if !ok { + return false + } + return EqualsRefOfUnlockTables(a, b) + case *Update: + b, ok := inB.(*Update) + if !ok { + return false + } + return EqualsRefOfUpdate(a, b) + case *Use: + b, ok := inB.(*Use) + if !ok { + return false + } + return EqualsRefOfUse(a, b) + case *VStream: + b, ok := inB.(*VStream) + if !ok { + return false + } + return EqualsRefOfVStream(a, b) + default: + // this should never happen + return false } - out := *n - out.Qualifier = CloneTableIdent(n.Qualifier) - out.Name = CloneColIdent(n.Name) - out.Exprs = CloneSelectExprs(n.Exprs) - return &out } -// VisitRefOfFuncExpr will visit all parts of the AST -func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { - if in == nil { - return nil +// EqualsTableAndLockTypes does deep equals between the two objects. +func EqualsTableAndLockTypes(a, b TableAndLockTypes) bool { + if len(a) != len(b) { + return false } - if cont, err := f(in); err != nil || !cont { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfTableAndLockType(a[i], b[i]) { + return false + } } - if err := VisitTableIdent(in.Qualifier, f); err != nil { - return err + return true +} + +// EqualsTableExpr does deep equals between the two objects. +func EqualsTableExpr(inA, inB TableExpr) bool { + if inA == nil && inB == nil { + return true } - if err := VisitColIdent(in.Name, f); err != nil { - return err + if inA == nil || inB == nil { + return false } - if err := VisitSelectExprs(in.Exprs, f); err != nil { - return err + switch a := inA.(type) { + case *AliasedTableExpr: + b, ok := inB.(*AliasedTableExpr) + if !ok { + return false + } + return EqualsRefOfAliasedTableExpr(a, b) + case *JoinTableExpr: + b, ok := inB.(*JoinTableExpr) + if !ok { + return false + } + return EqualsRefOfJoinTableExpr(a, b) + case *ParenTableExpr: + b, ok := inB.(*ParenTableExpr) + if !ok { + return false + } + return EqualsRefOfParenTableExpr(a, b) + default: + // this should never happen + return false } - return nil } -// EqualsGroupBy does deep equals between the two objects. -func EqualsGroupBy(a, b GroupBy) bool { +// EqualsTableExprs does deep equals between the two objects. +func EqualsTableExprs(a, b TableExprs) bool { if len(a) != len(b) { return false } for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { + if !EqualsTableExpr(a[i], b[i]) { return false } } return true } -// CloneGroupBy creates a deep clone of the input. -func CloneGroupBy(n GroupBy) GroupBy { - res := make(GroupBy, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) - } - return res +// EqualsTableIdent does deep equals between the two objects. +func EqualsTableIdent(a, b TableIdent) bool { + return a.v == b.v } -// VisitGroupBy will visit all parts of the AST -func VisitGroupBy(in GroupBy, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsTableName does deep equals between the two objects. +func EqualsTableName(a, b TableName) bool { + return EqualsTableIdent(a.Name, b.Name) && + EqualsTableIdent(a.Qualifier, b.Qualifier) +} + +// EqualsTableNames does deep equals between the two objects. +func EqualsTableNames(a, b TableNames) bool { + if len(a) != len(b) { + return false } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsTableName(a[i], b[i]) { + return false } } - return nil + return true } -// EqualsRefOfGroupConcatExpr does deep equals between the two objects. -func EqualsRefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsTableOptions does deep equals between the two objects. +func EqualsTableOptions(a, b TableOptions) bool { + if len(a) != len(b) { return false } - return a.Distinct == b.Distinct && - a.Separator == b.Separator && - EqualsSelectExprs(a.Exprs, b.Exprs) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfGroupConcatExpr creates a deep clone of the input. -func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { - if n == nil { - return nil + for i := 0; i < len(a); i++ { + if !EqualsRefOfTableOption(a[i], b[i]) { + return false + } } - out := *n - out.Exprs = CloneSelectExprs(n.Exprs) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out + return true } -// VisitRefOfGroupConcatExpr will visit all parts of the AST -func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { - if in == nil { - return nil - } - if cont, err := f(in); err != nil || !cont { - return err +// EqualsUpdateExprs does deep equals between the two objects. +func EqualsUpdateExprs(a, b UpdateExprs) bool { + if len(a) != len(b) { + return false } - if err := VisitSelectExprs(in.Exprs, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsRefOfUpdateExpr(a[i], b[i]) { + return false + } } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + return true +} + +// EqualsValTuple does deep equals between the two objects. +func EqualsValTuple(a, b ValTuple) bool { + if len(a) != len(b) { + return false } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err + for i := 0; i < len(a); i++ { + if !EqualsExpr(a[i], b[i]) { + return false + } } - return nil + return true } -// EqualsRefOfIndexDefinition does deep equals between the two objects. -func EqualsRefOfIndexDefinition(a, b *IndexDefinition) bool { - if a == b { - return true - } - if a == nil || b == nil { +// EqualsValues does deep equals between the two objects. +func EqualsValues(a, b Values) bool { + if len(a) != len(b) { return false } - return EqualsRefOfIndexInfo(a.Info, b.Info) && - EqualsSliceOfRefOfIndexColumn(a.Columns, b.Columns) && - EqualsSliceOfRefOfIndexOption(a.Options, b.Options) + for i := 0; i < len(a); i++ { + if !EqualsValTuple(a[i], b[i]) { + return false + } + } + return true } -// CloneRefOfIndexDefinition creates a deep clone of the input. -func CloneRefOfIndexDefinition(n *IndexDefinition) *IndexDefinition { - if n == nil { +// EqualsVindexParam does deep equals between the two objects. +func EqualsVindexParam(a, b VindexParam) bool { + return a.Val == b.Val && + EqualsColIdent(a.Key, b.Key) +} +func VisitAccessMode(in AccessMode, f Visit) error { + _, err := f(in) + return err +} +func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { + _, err := f(in) + return err +} +func VisitAlterOption(in AlterOption, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *AddColumns: + return VisitRefOfAddColumns(in, f) + case *AddConstraintDefinition: + return VisitRefOfAddConstraintDefinition(in, f) + case *AddIndexDefinition: + return VisitRefOfAddIndexDefinition(in, f) + case AlgorithmValue: + return VisitAlgorithmValue(in, f) + case *AlterCharset: + return VisitRefOfAlterCharset(in, f) + case *AlterColumn: + return VisitRefOfAlterColumn(in, f) + case *ChangeColumn: + return VisitRefOfChangeColumn(in, f) + case *DropColumn: + return VisitRefOfDropColumn(in, f) + case *DropKey: + return VisitRefOfDropKey(in, f) + case *Force: + return VisitRefOfForce(in, f) + case *KeyState: + return VisitRefOfKeyState(in, f) + case *LockOption: + return VisitRefOfLockOption(in, f) + case *ModifyColumn: + return VisitRefOfModifyColumn(in, f) + case *OrderByOption: + return VisitRefOfOrderByOption(in, f) + case *RenameIndex: + return VisitRefOfRenameIndex(in, f) + case *RenameTableName: + return VisitRefOfRenameTableName(in, f) + case TableOptions: + return VisitTableOptions(in, f) + case *TablespaceOperation: + return VisitRefOfTablespaceOperation(in, f) + case *Validation: + return VisitRefOfValidation(in, f) + default: + // this should never happen return nil } - out := *n - out.Info = CloneRefOfIndexInfo(n.Info) - out.Columns = CloneSliceOfRefOfIndexColumn(n.Columns) - out.Options = CloneSliceOfRefOfIndexOption(n.Options) - return &out } - -// VisitRefOfIndexDefinition will visit all parts of the AST -func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { +func VisitArgument(in Argument, f Visit) error { + _, err := f(in) + return err +} +func VisitBoolVal(in BoolVal, f Visit) error { + _, err := f(in) + return err +} +func VisitCharacteristic(in Characteristic, f Visit) error { if in == nil { return nil } + switch in := in.(type) { + case AccessMode: + return VisitAccessMode(in, f) + case IsolationLevel: + return VisitIsolationLevel(in, f) + default: + // this should never happen + return nil + } +} +func VisitColIdent(in ColIdent, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfIndexInfo(in.Info, f); err != nil { - return err - } return nil } - -// EqualsRefOfIndexHints does deep equals between the two objects. -func EqualsRefOfIndexHints(a, b *IndexHints) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitColTuple(in ColTuple, f Visit) error { + if in == nil { + return nil } - return a.Type == b.Type && - EqualsSliceOfColIdent(a.Indexes, b.Indexes) -} - -// CloneRefOfIndexHints creates a deep clone of the input. -func CloneRefOfIndexHints(n *IndexHints) *IndexHints { - if n == nil { + switch in := in.(type) { + case ListArg: + return VisitListArg(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case ValTuple: + return VisitValTuple(in, f) + default: + // this should never happen return nil } - out := *n - out.Indexes = CloneSliceOfColIdent(n.Indexes) - return &out } - -// VisitRefOfIndexHints will visit all parts of the AST -func VisitRefOfIndexHints(in *IndexHints, f Visit) error { +func VisitColumns(in Columns, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in.Indexes { + for _, el := range in { if err := VisitColIdent(el, f); err != nil { return err } } return nil } - -// EqualsRefOfIndexInfo does deep equals between the two objects. -func EqualsRefOfIndexInfo(a, b *IndexInfo) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Type == b.Type && - a.Primary == b.Primary && - a.Spatial == b.Spatial && - a.Fulltext == b.Fulltext && - a.Unique == b.Unique && - EqualsColIdent(a.Name, b.Name) && - EqualsColIdent(a.ConstraintName, b.ConstraintName) +func VisitComments(in Comments, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfIndexInfo creates a deep clone of the input. -func CloneRefOfIndexInfo(n *IndexInfo) *IndexInfo { - if n == nil { +func VisitConstraintInfo(in ConstraintInfo, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *CheckConstraintDefinition: + return VisitRefOfCheckConstraintDefinition(in, f) + case *ForeignKeyDefinition: + return VisitRefOfForeignKeyDefinition(in, f) + default: + // this should never happen return nil } - out := *n - out.Name = CloneColIdent(n.Name) - out.ConstraintName = CloneColIdent(n.ConstraintName) - return &out } - -// VisitRefOfIndexInfo will visit all parts of the AST -func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { +func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitColIdent(in.ConstraintName, f); err != nil { - return err + switch in := in.(type) { + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + default: + // this should never happen + return nil } - return nil } - -// EqualsRefOfInsert does deep equals between the two objects. -func EqualsRefOfInsert(a, b *Insert) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitDDLStatement(in DDLStatement, f Visit) error { + if in == nil { + return nil } - return a.Action == b.Action && - EqualsComments(a.Comments, b.Comments) && - a.Ignore == b.Ignore && - EqualsTableName(a.Table, b.Table) && - EqualsPartitions(a.Partitions, b.Partitions) && - EqualsColumns(a.Columns, b.Columns) && - EqualsInsertRows(a.Rows, b.Rows) && - EqualsOnDup(a.OnDup, b.OnDup) -} - -// CloneRefOfInsert creates a deep clone of the input. -func CloneRefOfInsert(n *Insert) *Insert { - if n == nil { + switch in := in.(type) { + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + default: + // this should never happen return nil } - out := *n - out.Comments = CloneComments(n.Comments) - out.Table = CloneTableName(n.Table) - out.Partitions = ClonePartitions(n.Partitions) - out.Columns = CloneColumns(n.Columns) - out.Rows = CloneInsertRows(n.Rows) - out.OnDup = CloneOnDup(n.OnDup) - return &out } - -// VisitRefOfInsert will visit all parts of the AST -func VisitRefOfInsert(in *Insert, f Visit) error { +func VisitExplain(in Explain, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitPartitions(in.Partitions, f); err != nil { - return err - } - if err := VisitColumns(in.Columns, f); err != nil { - return err - } - if err := VisitInsertRows(in.Rows, f); err != nil { - return err - } - if err := VisitOnDup(in.OnDup, f); err != nil { - return err + switch in := in.(type) { + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + default: + // this should never happen + return nil } - return nil } - -// EqualsRefOfIntervalExpr does deep equals between the two objects. -func EqualsRefOfIntervalExpr(a, b *IntervalExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitExpr(in Expr, f Visit) error { + if in == nil { + return nil } - return a.Unit == b.Unit && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfIntervalExpr creates a deep clone of the input. -func CloneRefOfIntervalExpr(n *IntervalExpr) *IntervalExpr { - if n == nil { + switch in := in.(type) { + case *AndExpr: + return VisitRefOfAndExpr(in, f) + case Argument: + return VisitArgument(in, f) + case *BinaryExpr: + return VisitRefOfBinaryExpr(in, f) + case BoolVal: + return VisitBoolVal(in, f) + case *CaseExpr: + return VisitRefOfCaseExpr(in, f) + case *ColName: + return VisitRefOfColName(in, f) + case *CollateExpr: + return VisitRefOfCollateExpr(in, f) + case *ComparisonExpr: + return VisitRefOfComparisonExpr(in, f) + case *ConvertExpr: + return VisitRefOfConvertExpr(in, f) + case *ConvertUsingExpr: + return VisitRefOfConvertUsingExpr(in, f) + case *CurTimeFuncExpr: + return VisitRefOfCurTimeFuncExpr(in, f) + case *Default: + return VisitRefOfDefault(in, f) + case *ExistsExpr: + return VisitRefOfExistsExpr(in, f) + case *FuncExpr: + return VisitRefOfFuncExpr(in, f) + case *GroupConcatExpr: + return VisitRefOfGroupConcatExpr(in, f) + case *IntervalExpr: + return VisitRefOfIntervalExpr(in, f) + case *IsExpr: + return VisitRefOfIsExpr(in, f) + case ListArg: + return VisitListArg(in, f) + case *Literal: + return VisitRefOfLiteral(in, f) + case *MatchExpr: + return VisitRefOfMatchExpr(in, f) + case *NotExpr: + return VisitRefOfNotExpr(in, f) + case *NullVal: + return VisitRefOfNullVal(in, f) + case *OrExpr: + return VisitRefOfOrExpr(in, f) + case *RangeCond: + return VisitRefOfRangeCond(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case *SubstrExpr: + return VisitRefOfSubstrExpr(in, f) + case *TimestampFuncExpr: + return VisitRefOfTimestampFuncExpr(in, f) + case *UnaryExpr: + return VisitRefOfUnaryExpr(in, f) + case ValTuple: + return VisitValTuple(in, f) + case *ValuesFuncExpr: + return VisitRefOfValuesFuncExpr(in, f) + case *XorExpr: + return VisitRefOfXorExpr(in, f) + default: + // this should never happen return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out } - -// VisitRefOfIntervalExpr will visit all parts of the AST -func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { +func VisitExprs(in Exprs, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { - return err + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfIsExpr does deep equals between the two objects. -func EqualsRefOfIsExpr(a, b *IsExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Operator == b.Operator && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfIsExpr creates a deep clone of the input. -func CloneRefOfIsExpr(n *IsExpr) *IsExpr { - if n == nil { - return nil - } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out -} - -// VisitRefOfIsExpr will visit all parts of the AST -func VisitRefOfIsExpr(in *IsExpr, f Visit) error { +func VisitGroupBy(in GroupBy, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { - return err + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsJoinCondition does deep equals between the two objects. -func EqualsJoinCondition(a, b JoinCondition) bool { - return EqualsExpr(a.On, b.On) && - EqualsColumns(a.Using, b.Using) +func VisitInsertRows(in InsertRows, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case Values: + return VisitValues(in, f) + default: + // this should never happen + return nil + } } - -// CloneJoinCondition creates a deep clone of the input. -func CloneJoinCondition(n JoinCondition) JoinCondition { - return *CloneRefOfJoinCondition(&n) +func VisitIsolationLevel(in IsolationLevel, f Visit) error { + _, err := f(in) + return err } - -// VisitJoinCondition will visit all parts of the AST func VisitJoinCondition(in JoinCondition, f Visit) error { if cont, err := f(in); err != nil || !cont { return err @@ -3977,204 +6881,133 @@ func VisitJoinCondition(in JoinCondition, f Visit) error { } return nil } - -// EqualsRefOfJoinTableExpr does deep equals between the two objects. -func EqualsRefOfJoinTableExpr(a, b *JoinTableExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableExpr(a.LeftExpr, b.LeftExpr) && - a.Join == b.Join && - EqualsTableExpr(a.RightExpr, b.RightExpr) && - EqualsJoinCondition(a.Condition, b.Condition) -} - -// CloneRefOfJoinTableExpr creates a deep clone of the input. -func CloneRefOfJoinTableExpr(n *JoinTableExpr) *JoinTableExpr { - if n == nil { - return nil - } - out := *n - out.LeftExpr = CloneTableExpr(n.LeftExpr) - out.RightExpr = CloneTableExpr(n.RightExpr) - out.Condition = CloneJoinCondition(n.Condition) - return &out +func VisitListArg(in ListArg, f Visit) error { + _, err := f(in) + return err } - -// VisitRefOfJoinTableExpr will visit all parts of the AST -func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { +func VisitOnDup(in OnDup, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableExpr(in.LeftExpr, f); err != nil { - return err - } - if err := VisitTableExpr(in.RightExpr, f); err != nil { - return err - } - if err := VisitJoinCondition(in.Condition, f); err != nil { - return err + for _, el := range in { + if err := VisitRefOfUpdateExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfKeyState does deep equals between the two objects. -func EqualsRefOfKeyState(a, b *KeyState) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Enable == b.Enable -} - -// CloneRefOfKeyState creates a deep clone of the input. -func CloneRefOfKeyState(n *KeyState) *KeyState { - if n == nil { - return nil - } - out := *n - return &out -} - -// VisitRefOfKeyState will visit all parts of the AST -func VisitRefOfKeyState(in *KeyState, f Visit) error { +func VisitOrderBy(in OrderBy, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + for _, el := range in { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } + } return nil } - -// EqualsRefOfLimit does deep equals between the two objects. -func EqualsRefOfLimit(a, b *Limit) bool { - if a == b { - return true +func VisitPartitions(in Partitions, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Offset, b.Offset) && - EqualsExpr(a.Rowcount, b.Rowcount) -} - -// CloneRefOfLimit creates a deep clone of the input. -func CloneRefOfLimit(n *Limit) *Limit { - if n == nil { - return nil + for _, el := range in { + if err := VisitColIdent(el, f); err != nil { + return err + } } - out := *n - out.Offset = CloneExpr(n.Offset) - out.Rowcount = CloneExpr(n.Rowcount) - return &out + return nil } - -// VisitRefOfLimit will visit all parts of the AST -func VisitRefOfLimit(in *Limit, f Visit) error { +func VisitRefOfAddColumns(in *AddColumns, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Offset, f); err != nil { + for _, el := range in.Columns { + if err := VisitRefOfColumnDefinition(el, f); err != nil { + return err + } + } + if err := VisitRefOfColName(in.First, f); err != nil { return err } - if err := VisitExpr(in.Rowcount, f); err != nil { + if err := VisitRefOfColName(in.After, f); err != nil { return err } return nil } - -// EqualsListArg does deep equals between the two objects. -func EqualsListArg(a, b ListArg) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } +func VisitRefOfAddConstraintDefinition(in *AddConstraintDefinition, f Visit) error { + if in == nil { + return nil } - return true -} - -// CloneListArg creates a deep clone of the input. -func CloneListArg(n ListArg) ListArg { - res := make(ListArg, 0, len(n)) - copy(res, n) - return res -} - -// VisitListArg will visit all parts of the AST -func VisitListArg(in ListArg, f Visit) error { - _, err := f(in) - return err -} - -// EqualsRefOfLiteral does deep equals between the two objects. -func EqualsRefOfLiteral(a, b *Literal) bool { - if a == b { - return true + if cont, err := f(in); err != nil || !cont { + return err } - if a == nil || b == nil { - return false + if err := VisitRefOfConstraintDefinition(in.ConstraintDefinition, f); err != nil { + return err } - return a.Val == b.Val && - a.Type == b.Type + return nil } - -// CloneRefOfLiteral creates a deep clone of the input. -func CloneRefOfLiteral(n *Literal) *Literal { - if n == nil { +func VisitRefOfAddIndexDefinition(in *AddIndexDefinition, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitRefOfIndexDefinition(in.IndexDefinition, f); err != nil { + return err + } + return nil } - -// VisitRefOfLiteral will visit all parts of the AST -func VisitRefOfLiteral(in *Literal, f Visit) error { +func VisitRefOfAliasedExpr(in *AliasedExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfLoad does deep equals between the two objects. -func EqualsRefOfLoad(a, b *Load) bool { - if a == b { - return true + if err := VisitExpr(in.Expr, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColIdent(in.As, f); err != nil { + return err } - return true + return nil } - -// CloneRefOfLoad creates a deep clone of the input. -func CloneRefOfLoad(n *Load) *Load { - if n == nil { +func VisitRefOfAliasedTableExpr(in *AliasedTableExpr, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitSimpleTableExpr(in.Expr, f); err != nil { + return err + } + if err := VisitPartitions(in.Partitions, f); err != nil { + return err + } + if err := VisitTableIdent(in.As, f); err != nil { + return err + } + if err := VisitRefOfIndexHints(in.Hints, f); err != nil { + return err + } + return nil } - -// VisitRefOfLoad will visit all parts of the AST -func VisitRefOfLoad(in *Load, f Visit) error { +func VisitRefOfAlterCharset(in *AlterCharset, f Visit) error { if in == nil { return nil } @@ -4183,29 +7016,22 @@ func VisitRefOfLoad(in *Load, f Visit) error { } return nil } - -// EqualsRefOfLockOption does deep equals between the two objects. -func EqualsRefOfLockOption(a, b *LockOption) bool { - if a == b { - return true +func VisitRefOfAlterColumn(in *AlterColumn, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Type == b.Type -} - -// CloneRefOfLockOption creates a deep clone of the input. -func CloneRefOfLockOption(n *LockOption) *LockOption { - if n == nil { - return nil + if err := VisitRefOfColName(in.Column, f); err != nil { + return err } - out := *n - return &out + if err := VisitExpr(in.DefaultVal, f); err != nil { + return err + } + return nil } - -// VisitRefOfLockOption will visit all parts of the AST -func VisitRefOfLockOption(in *LockOption, f Visit) error { +func VisitRefOfAlterDatabase(in *AlterDatabase, f Visit) error { if in == nil { return nil } @@ -4214,183 +7040,187 @@ func VisitRefOfLockOption(in *LockOption, f Visit) error { } return nil } - -// EqualsRefOfLockTables does deep equals between the two objects. -func EqualsRefOfLockTables(a, b *LockTables) bool { - if a == b { - return true +func VisitRefOfAlterMigration(in *AlterMigration, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableAndLockTypes(a.Tables, b.Tables) + return nil } - -// CloneRefOfLockTables creates a deep clone of the input. -func CloneRefOfLockTables(n *LockTables) *LockTables { - if n == nil { +func VisitRefOfAlterTable(in *AlterTable, f Visit) error { + if in == nil { return nil } - out := *n - out.Tables = CloneTableAndLockTypes(n.Tables) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitTableName(in.Table, f); err != nil { + return err + } + for _, el := range in.AlterOptions { + if err := VisitAlterOption(el, f); err != nil { + return err + } + } + if err := VisitRefOfPartitionSpec(in.PartitionSpec, f); err != nil { + return err + } + return nil } - -// VisitRefOfLockTables will visit all parts of the AST -func VisitRefOfLockTables(in *LockTables, f Visit) error { +func VisitRefOfAlterView(in *AlterView, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfMatchExpr does deep equals between the two objects. -func EqualsRefOfMatchExpr(a, b *MatchExpr) bool { - if a == b { - return true + if err := VisitTableName(in.ViewName, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColumns(in.Columns, f); err != nil { + return err } - return EqualsSelectExprs(a.Columns, b.Columns) && - EqualsExpr(a.Expr, b.Expr) && - a.Option == b.Option -} - -// CloneRefOfMatchExpr creates a deep clone of the input. -func CloneRefOfMatchExpr(n *MatchExpr) *MatchExpr { - if n == nil { - return nil + if err := VisitSelectStatement(in.Select, f); err != nil { + return err } - out := *n - out.Columns = CloneSelectExprs(n.Columns) - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfMatchExpr will visit all parts of the AST -func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { +func VisitRefOfAlterVschema(in *AlterVschema, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectExprs(in.Columns, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { return err } - if err := VisitExpr(in.Expr, f); err != nil { + if err := VisitRefOfVindexSpec(in.VindexSpec, f); err != nil { + return err + } + for _, el := range in.VindexCols { + if err := VisitColIdent(el, f); err != nil { + return err + } + } + if err := VisitRefOfAutoIncSpec(in.AutoIncSpec, f); err != nil { return err } return nil } - -// EqualsRefOfModifyColumn does deep equals between the two objects. -func EqualsRefOfModifyColumn(a, b *ModifyColumn) bool { - if a == b { - return true +func VisitRefOfAndExpr(in *AndExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsRefOfColumnDefinition(a.NewColDefinition, b.NewColDefinition) && - EqualsRefOfColName(a.First, b.First) && - EqualsRefOfColName(a.After, b.After) + if err := VisitExpr(in.Left, f); err != nil { + return err + } + if err := VisitExpr(in.Right, f); err != nil { + return err + } + return nil } - -// CloneRefOfModifyColumn creates a deep clone of the input. -func CloneRefOfModifyColumn(n *ModifyColumn) *ModifyColumn { - if n == nil { +func VisitRefOfAutoIncSpec(in *AutoIncSpec, f Visit) error { + if in == nil { return nil } - out := *n - out.NewColDefinition = CloneRefOfColumnDefinition(n.NewColDefinition) - out.First = CloneRefOfColName(n.First) - out.After = CloneRefOfColName(n.After) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitColIdent(in.Column, f); err != nil { + return err + } + if err := VisitTableName(in.Sequence, f); err != nil { + return err + } + return nil } - -// VisitRefOfModifyColumn will visit all parts of the AST -func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { +func VisitRefOfBegin(in *Begin, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { + return nil +} +func VisitRefOfBinaryExpr(in *BinaryExpr, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColName(in.First, f); err != nil { + if err := VisitExpr(in.Left, f); err != nil { return err } - if err := VisitRefOfColName(in.After, f); err != nil { + if err := VisitExpr(in.Right, f); err != nil { return err } return nil } - -// EqualsRefOfNextval does deep equals between the two objects. -func EqualsRefOfNextval(a, b *Nextval) bool { - if a == b { - return true +func VisitRefOfCallProc(in *CallProc, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Expr, b.Expr) + if err := VisitTableName(in.Name, f); err != nil { + return err + } + if err := VisitExprs(in.Params, f); err != nil { + return err + } + return nil } - -// CloneRefOfNextval creates a deep clone of the input. -func CloneRefOfNextval(n *Nextval) *Nextval { - if n == nil { +func VisitRefOfCaseExpr(in *CaseExpr, f Visit) error { + if in == nil { return nil } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + for _, el := range in.Whens { + if err := VisitRefOfWhen(el, f); err != nil { + return err + } + } + if err := VisitExpr(in.Else, f); err != nil { + return err + } + return nil } - -// VisitRefOfNextval will visit all parts of the AST -func VisitRefOfNextval(in *Nextval, f Visit) error { +func VisitRefOfChangeColumn(in *ChangeColumn, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { + if err := VisitRefOfColName(in.OldColumn, f); err != nil { return err } - return nil -} - -// EqualsRefOfNotExpr does deep equals between the two objects. -func EqualsRefOfNotExpr(a, b *NotExpr) bool { - if a == b { - return true + if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitRefOfColName(in.First, f); err != nil { + return err } - return EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfNotExpr creates a deep clone of the input. -func CloneRefOfNotExpr(n *NotExpr) *NotExpr { - if n == nil { - return nil + if err := VisitRefOfColName(in.After, f); err != nil { + return err } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfNotExpr will visit all parts of the AST -func VisitRefOfNotExpr(in *NotExpr, f Visit) error { +func VisitRefOfCheckConstraintDefinition(in *CheckConstraintDefinition, f Visit) error { if in == nil { return nil } @@ -4402,29 +7232,7 @@ func VisitRefOfNotExpr(in *NotExpr, f Visit) error { } return nil } - -// EqualsRefOfNullVal does deep equals between the two objects. -func EqualsRefOfNullVal(a, b *NullVal) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return true -} - -// CloneRefOfNullVal creates a deep clone of the input. -func CloneRefOfNullVal(n *NullVal) *NullVal { - if n == nil { - return nil - } - out := *n - return &out -} - -// VisitRefOfNullVal will visit all parts of the AST -func VisitRefOfNullVal(in *NullVal, f Visit) error { +func VisitRefOfColIdent(in *ColIdent, f Visit) error { if in == nil { return nil } @@ -4433,105 +7241,70 @@ func VisitRefOfNullVal(in *NullVal, f Visit) error { } return nil } - -// EqualsOnDup does deep equals between the two objects. -func EqualsOnDup(a, b OnDup) bool { - if len(a) != len(b) { - return false +func VisitRefOfColName(in *ColName, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUpdateExpr(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneOnDup creates a deep clone of the input. -func CloneOnDup(n OnDup) OnDup { - res := make(OnDup, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUpdateExpr(x)) + if err := VisitColIdent(in.Name, f); err != nil { + return err } - return res + if err := VisitTableName(in.Qualifier, f); err != nil { + return err + } + return nil } - -// VisitOnDup will visit all parts of the AST -func VisitOnDup(in OnDup, f Visit) error { +func VisitRefOfCollateExpr(in *CollateExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitRefOfUpdateExpr(el, f); err != nil { - return err - } + if err := VisitExpr(in.Expr, f); err != nil { + return err } return nil } - -// EqualsRefOfOptLike does deep equals between the two objects. -func EqualsRefOfOptLike(a, b *OptLike) bool { - if a == b { - return true +func VisitRefOfColumnDefinition(in *ColumnDefinition, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableName(a.LikeTable, b.LikeTable) -} - -// CloneRefOfOptLike creates a deep clone of the input. -func CloneRefOfOptLike(n *OptLike) *OptLike { - if n == nil { - return nil + if err := VisitColIdent(in.Name, f); err != nil { + return err } - out := *n - out.LikeTable = CloneTableName(n.LikeTable) - return &out + return nil } - -// VisitRefOfOptLike will visit all parts of the AST -func VisitRefOfOptLike(in *OptLike, f Visit) error { +func VisitRefOfColumnType(in *ColumnType, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.LikeTable, f); err != nil { + if err := VisitRefOfLiteral(in.Length, f); err != nil { return err } - return nil -} - -// EqualsRefOfOrExpr does deep equals between the two objects. -func EqualsRefOfOrExpr(a, b *OrExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + if err := VisitRefOfLiteral(in.Scale, f); err != nil { + return err } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) + return nil } - -// CloneRefOfOrExpr creates a deep clone of the input. -func CloneRefOfOrExpr(n *OrExpr) *OrExpr { - if n == nil { +func VisitRefOfCommit(in *Commit, f Visit) error { + if in == nil { return nil } - out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfOrExpr will visit all parts of the AST -func VisitRefOfOrExpr(in *OrExpr, f Visit) error { +func VisitRefOfComparisonExpr(in *ComparisonExpr, f Visit) error { if in == nil { return nil } @@ -4544,33 +7317,24 @@ func VisitRefOfOrExpr(in *OrExpr, f Visit) error { if err := VisitExpr(in.Right, f); err != nil { return err } + if err := VisitExpr(in.Escape, f); err != nil { + return err + } return nil } - -// EqualsRefOfOrder does deep equals between the two objects. -func EqualsRefOfOrder(a, b *Order) bool { - if a == b { - return true +func VisitRefOfConstraintDefinition(in *ConstraintDefinition, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsExpr(a.Expr, b.Expr) && - a.Direction == b.Direction -} - -// CloneRefOfOrder creates a deep clone of the input. -func CloneRefOfOrder(n *Order) *Order { - if n == nil { - return nil + if err := VisitConstraintInfo(in.Details, f); err != nil { + return err } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfOrder will visit all parts of the AST -func VisitRefOfOrder(in *Order, f Visit) error { +func VisitRefOfConvertExpr(in *ConvertExpr, f Visit) error { if in == nil { return nil } @@ -4580,807 +7344,569 @@ func VisitRefOfOrder(in *Order, f Visit) error { if err := VisitExpr(in.Expr, f); err != nil { return err } - return nil -} - -// EqualsOrderBy does deep equals between the two objects. -func EqualsOrderBy(a, b OrderBy) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfOrder(a[i], b[i]) { - return false - } - } - return true -} - -// CloneOrderBy creates a deep clone of the input. -func CloneOrderBy(n OrderBy) OrderBy { - res := make(OrderBy, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfOrder(x)) + if err := VisitRefOfConvertType(in.Type, f); err != nil { + return err } - return res + return nil } - -// VisitOrderBy will visit all parts of the AST -func VisitOrderBy(in OrderBy, f Visit) error { +func VisitRefOfConvertType(in *ConvertType, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitRefOfOrder(el, f); err != nil { - return err - } - } - return nil -} - -// EqualsRefOfOrderByOption does deep equals between the two objects. -func EqualsRefOfOrderByOption(a, b *OrderByOption) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + if err := VisitRefOfLiteral(in.Length, f); err != nil { + return err } - return EqualsColumns(a.Cols, b.Cols) -} - -// CloneRefOfOrderByOption creates a deep clone of the input. -func CloneRefOfOrderByOption(n *OrderByOption) *OrderByOption { - if n == nil { - return nil + if err := VisitRefOfLiteral(in.Scale, f); err != nil { + return err } - out := *n - out.Cols = CloneColumns(n.Cols) - return &out + return nil } - -// VisitRefOfOrderByOption will visit all parts of the AST -func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { +func VisitRefOfConvertUsingExpr(in *ConvertUsingExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColumns(in.Cols, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } return nil } - -// EqualsRefOfOtherAdmin does deep equals between the two objects. -func EqualsRefOfOtherAdmin(a, b *OtherAdmin) bool { - if a == b { - return true +func VisitRefOfCreateDatabase(in *CreateDatabase, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneRefOfOtherAdmin creates a deep clone of the input. -func CloneRefOfOtherAdmin(n *OtherAdmin) *OtherAdmin { - if n == nil { - return nil + if err := VisitComments(in.Comments, f); err != nil { + return err } - out := *n - return &out + return nil } - -// VisitRefOfOtherAdmin will visit all parts of the AST -func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { +func VisitRefOfCreateTable(in *CreateTable, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfOtherRead does deep equals between the two objects. -func EqualsRefOfOtherRead(a, b *OtherRead) bool { - if a == b { - return true + if err := VisitTableName(in.Table, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitRefOfTableSpec(in.TableSpec, f); err != nil { + return err } - return true -} - -// CloneRefOfOtherRead creates a deep clone of the input. -func CloneRefOfOtherRead(n *OtherRead) *OtherRead { - if n == nil { - return nil + if err := VisitRefOfOptLike(in.OptLike, f); err != nil { + return err } - out := *n - return &out + return nil } - -// VisitRefOfOtherRead will visit all parts of the AST -func VisitRefOfOtherRead(in *OtherRead, f Visit) error { +func VisitRefOfCreateView(in *CreateView, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfParenSelect does deep equals between the two objects. -func EqualsRefOfParenSelect(a, b *ParenSelect) bool { - if a == b { - return true + if err := VisitTableName(in.ViewName, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColumns(in.Columns, f); err != nil { + return err } - return EqualsSelectStatement(a.Select, b.Select) -} - -// CloneRefOfParenSelect creates a deep clone of the input. -func CloneRefOfParenSelect(n *ParenSelect) *ParenSelect { - if n == nil { - return nil + if err := VisitSelectStatement(in.Select, f); err != nil { + return err } - out := *n - out.Select = CloneSelectStatement(n.Select) - return &out + return nil } - -// VisitRefOfParenSelect will visit all parts of the AST -func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { +func VisitRefOfCurTimeFuncExpr(in *CurTimeFuncExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectStatement(in.Select, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } - return nil -} - -// EqualsRefOfParenTableExpr does deep equals between the two objects. -func EqualsRefOfParenTableExpr(a, b *ParenTableExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + if err := VisitExpr(in.Fsp, f); err != nil { + return err } - return EqualsTableExprs(a.Exprs, b.Exprs) + return nil } - -// CloneRefOfParenTableExpr creates a deep clone of the input. -func CloneRefOfParenTableExpr(n *ParenTableExpr) *ParenTableExpr { - if n == nil { +func VisitRefOfDefault(in *Default, f Visit) error { + if in == nil { return nil } - out := *n - out.Exprs = CloneTableExprs(n.Exprs) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfParenTableExpr will visit all parts of the AST -func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { +func VisitRefOfDelete(in *Delete, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableExprs(in.Exprs, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - return nil -} - -// EqualsRefOfPartitionDefinition does deep equals between the two objects. -func EqualsRefOfPartitionDefinition(a, b *PartitionDefinition) bool { - if a == b { - return true + if err := VisitTableNames(in.Targets, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitTableExprs(in.TableExprs, f); err != nil { + return err } - return a.Maxvalue == b.Maxvalue && - EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Limit, b.Limit) -} - -// CloneRefOfPartitionDefinition creates a deep clone of the input. -func CloneRefOfPartitionDefinition(n *PartitionDefinition) *PartitionDefinition { - if n == nil { - return nil + if err := VisitPartitions(in.Partitions, f); err != nil { + return err } - out := *n - out.Name = CloneColIdent(n.Name) - out.Limit = CloneExpr(n.Limit) - return &out + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err + } + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err + } + return nil } - -// VisitRefOfPartitionDefinition will visit all parts of the AST -func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { +func VisitRefOfDerivedTable(in *DerivedTable, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Limit, f); err != nil { + if err := VisitSelectStatement(in.Select, f); err != nil { return err } return nil } - -// EqualsRefOfPartitionSpec does deep equals between the two objects. -func EqualsRefOfPartitionSpec(a, b *PartitionSpec) bool { - if a == b { - return true +func VisitRefOfDropColumn(in *DropColumn, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.IsAll == b.IsAll && - a.WithoutValidation == b.WithoutValidation && - a.Action == b.Action && - EqualsPartitions(a.Names, b.Names) && - EqualsRefOfLiteral(a.Number, b.Number) && - EqualsTableName(a.TableName, b.TableName) && - EqualsSliceOfRefOfPartitionDefinition(a.Definitions, b.Definitions) -} - -// CloneRefOfPartitionSpec creates a deep clone of the input. -func CloneRefOfPartitionSpec(n *PartitionSpec) *PartitionSpec { - if n == nil { - return nil + if err := VisitRefOfColName(in.Name, f); err != nil { + return err } - out := *n - out.Names = ClonePartitions(n.Names) - out.Number = CloneRefOfLiteral(n.Number) - out.TableName = CloneTableName(n.TableName) - out.Definitions = CloneSliceOfRefOfPartitionDefinition(n.Definitions) - return &out + return nil } - -// VisitRefOfPartitionSpec will visit all parts of the AST -func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { +func VisitRefOfDropDatabase(in *DropDatabase, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitPartitions(in.Names, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - if err := VisitRefOfLiteral(in.Number, f); err != nil { + return nil +} +func VisitRefOfDropKey(in *DropKey, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.TableName, f); err != nil { + return nil +} +func VisitRefOfDropTable(in *DropTable, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in.Definitions { - if err := VisitRefOfPartitionDefinition(el, f); err != nil { - return err - } + if err := VisitTableNames(in.FromTables, f); err != nil { + return err } return nil } - -// EqualsPartitions does deep equals between the two objects. -func EqualsPartitions(a, b Partitions) bool { - if len(a) != len(b) { - return false +func VisitRefOfDropView(in *DropView, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// ClonePartitions creates a deep clone of the input. -func ClonePartitions(n Partitions) Partitions { - res := make(Partitions, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) + if err := VisitTableNames(in.FromTables, f); err != nil { + return err } - return res -} - -// VisitPartitions will visit all parts of the AST -func VisitPartitions(in Partitions, f Visit) error { + return nil +} +func VisitRefOfExistsExpr(in *ExistsExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitColIdent(el, f); err != nil { - return err - } + if err := VisitRefOfSubquery(in.Subquery, f); err != nil { + return err } return nil } - -// EqualsRefOfRangeCond does deep equals between the two objects. -func EqualsRefOfRangeCond(a, b *RangeCond) bool { - if a == b { - return true +func VisitRefOfExplainStmt(in *ExplainStmt, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Operator == b.Operator && - EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.From, b.From) && - EqualsExpr(a.To, b.To) -} - -// CloneRefOfRangeCond creates a deep clone of the input. -func CloneRefOfRangeCond(n *RangeCond) *RangeCond { - if n == nil { - return nil + if err := VisitStatement(in.Statement, f); err != nil { + return err } - out := *n - out.Left = CloneExpr(n.Left) - out.From = CloneExpr(n.From) - out.To = CloneExpr(n.To) - return &out + return nil } - -// VisitRefOfRangeCond will visit all parts of the AST -func VisitRefOfRangeCond(in *RangeCond, f Visit) error { +func VisitRefOfExplainTab(in *ExplainTab, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Left, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { return err } - if err := VisitExpr(in.From, f); err != nil { + return nil +} +func VisitRefOfFlush(in *Flush, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.To, f); err != nil { + if err := VisitTableNames(in.TableNames, f); err != nil { return err } return nil } - -// EqualsRefOfRelease does deep equals between the two objects. -func EqualsRefOfRelease(a, b *Release) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsColIdent(a.Name, b.Name) -} - -// CloneRefOfRelease creates a deep clone of the input. -func CloneRefOfRelease(n *Release) *Release { - if n == nil { +func VisitRefOfForce(in *Force, f Visit) error { + if in == nil { return nil } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfRelease will visit all parts of the AST -func VisitRefOfRelease(in *Release, f Visit) error { +func VisitRefOfForeignKeyDefinition(in *ForeignKeyDefinition, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitColumns(in.Source, f); err != nil { return err } - return nil -} - -// EqualsRefOfRenameIndex does deep equals between the two objects. -func EqualsRefOfRenameIndex(a, b *RenameIndex) bool { - if a == b { - return true + if err := VisitTableName(in.ReferencedTable, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColumns(in.ReferencedColumns, f); err != nil { + return err } - return a.OldName == b.OldName && - a.NewName == b.NewName -} - -// CloneRefOfRenameIndex creates a deep clone of the input. -func CloneRefOfRenameIndex(n *RenameIndex) *RenameIndex { - if n == nil { - return nil + if err := VisitReferenceAction(in.OnDelete, f); err != nil { + return err } - out := *n - return &out + if err := VisitReferenceAction(in.OnUpdate, f); err != nil { + return err + } + return nil } - -// VisitRefOfRenameIndex will visit all parts of the AST -func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { +func VisitRefOfFuncExpr(in *FuncExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfRenameTable does deep equals between the two objects. -func EqualsRefOfRenameTable(a, b *RenameTable) bool { - if a == b { - return true + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColIdent(in.Name, f); err != nil { + return err } - return EqualsSliceOfRefOfRenameTablePair(a.TablePairs, b.TablePairs) -} - -// CloneRefOfRenameTable creates a deep clone of the input. -func CloneRefOfRenameTable(n *RenameTable) *RenameTable { - if n == nil { - return nil + if err := VisitSelectExprs(in.Exprs, f); err != nil { + return err } - out := *n - out.TablePairs = CloneSliceOfRefOfRenameTablePair(n.TablePairs) - return &out + return nil } - -// VisitRefOfRenameTable will visit all parts of the AST -func VisitRefOfRenameTable(in *RenameTable, f Visit) error { +func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfRenameTableName does deep equals between the two objects. -func EqualsRefOfRenameTableName(a, b *RenameTableName) bool { - if a == b { - return true + if err := VisitSelectExprs(in.Exprs, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err } - return EqualsTableName(a.Table, b.Table) -} - -// CloneRefOfRenameTableName creates a deep clone of the input. -func CloneRefOfRenameTableName(n *RenameTableName) *RenameTableName { - if n == nil { - return nil + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err } - out := *n - out.Table = CloneTableName(n.Table) - return &out + return nil } - -// VisitRefOfRenameTableName will visit all parts of the AST -func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { +func VisitRefOfIndexDefinition(in *IndexDefinition, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.Table, f); err != nil { + if err := VisitRefOfIndexInfo(in.Info, f); err != nil { return err } return nil } - -// EqualsRefOfRevertMigration does deep equals between the two objects. -func EqualsRefOfRevertMigration(a, b *RevertMigration) bool { - if a == b { - return true +func VisitRefOfIndexHints(in *IndexHints, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.UUID == b.UUID -} - -// CloneRefOfRevertMigration creates a deep clone of the input. -func CloneRefOfRevertMigration(n *RevertMigration) *RevertMigration { - if n == nil { - return nil + for _, el := range in.Indexes { + if err := VisitColIdent(el, f); err != nil { + return err + } } - out := *n - return &out + return nil } - -// VisitRefOfRevertMigration will visit all parts of the AST -func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { +func VisitRefOfIndexInfo(in *IndexInfo, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitColIdent(in.Name, f); err != nil { + return err + } + if err := VisitColIdent(in.ConstraintName, f); err != nil { + return err + } return nil } - -// EqualsRefOfRollback does deep equals between the two objects. -func EqualsRefOfRollback(a, b *Rollback) bool { - if a == b { - return true +func VisitRefOfInsert(in *Insert, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitTableName(in.Table, f); err != nil { + return err + } + if err := VisitPartitions(in.Partitions, f); err != nil { + return err + } + if err := VisitColumns(in.Columns, f); err != nil { + return err + } + if err := VisitInsertRows(in.Rows, f); err != nil { + return err + } + if err := VisitOnDup(in.OnDup, f); err != nil { + return err + } + return nil } - -// CloneRefOfRollback creates a deep clone of the input. -func CloneRefOfRollback(n *Rollback) *Rollback { - if n == nil { +func VisitRefOfIntervalExpr(in *IntervalExpr, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil } - -// VisitRefOfRollback will visit all parts of the AST -func VisitRefOfRollback(in *Rollback, f Visit) error { +func VisitRefOfIsExpr(in *IsExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitExpr(in.Expr, f); err != nil { + return err + } return nil } - -// EqualsRefOfSRollback does deep equals between the two objects. -func EqualsRefOfSRollback(a, b *SRollback) bool { - if a == b { - return true +func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsColIdent(a.Name, b.Name) -} - -// CloneRefOfSRollback creates a deep clone of the input. -func CloneRefOfSRollback(n *SRollback) *SRollback { - if n == nil { - return nil + if err := VisitExpr(in.On, f); err != nil { + return err } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + if err := VisitColumns(in.Using, f); err != nil { + return err + } + return nil } - -// VisitRefOfSRollback will visit all parts of the AST -func VisitRefOfSRollback(in *SRollback, f Visit) error { +func VisitRefOfJoinTableExpr(in *JoinTableExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitTableExpr(in.LeftExpr, f); err != nil { return err } - return nil -} - -// EqualsRefOfSavepoint does deep equals between the two objects. -func EqualsRefOfSavepoint(a, b *Savepoint) bool { - if a == b { - return true + if err := VisitTableExpr(in.RightExpr, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitJoinCondition(in.Condition, f); err != nil { + return err } - return EqualsColIdent(a.Name, b.Name) + return nil } - -// CloneRefOfSavepoint creates a deep clone of the input. -func CloneRefOfSavepoint(n *Savepoint) *Savepoint { - if n == nil { +func VisitRefOfKeyState(in *KeyState, f Visit) error { + if in == nil { return nil } - out := *n - out.Name = CloneColIdent(n.Name) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfSavepoint will visit all parts of the AST -func VisitRefOfSavepoint(in *Savepoint, f Visit) error { +func VisitRefOfLimit(in *Limit, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { + if err := VisitExpr(in.Offset, f); err != nil { + return err + } + if err := VisitExpr(in.Rowcount, f); err != nil { return err } return nil } - -// EqualsRefOfSelect does deep equals between the two objects. -func EqualsRefOfSelect(a, b *Select) bool { - if a == b { - return true +func VisitRefOfLiteral(in *Literal, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Distinct == b.Distinct && - a.StraightJoinHint == b.StraightJoinHint && - a.SQLCalcFoundRows == b.SQLCalcFoundRows && - EqualsRefOfBool(a.Cache, b.Cache) && - EqualsComments(a.Comments, b.Comments) && - EqualsSelectExprs(a.SelectExprs, b.SelectExprs) && - EqualsTableExprs(a.From, b.From) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsGroupBy(a.GroupBy, b.GroupBy) && - EqualsRefOfWhere(a.Having, b.Having) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) && - a.Lock == b.Lock && - EqualsRefOfSelectInto(a.Into, b.Into) + return nil } - -// CloneRefOfSelect creates a deep clone of the input. -func CloneRefOfSelect(n *Select) *Select { - if n == nil { +func VisitRefOfLoad(in *Load, f Visit) error { + if in == nil { return nil } - out := *n - out.Cache = CloneRefOfBool(n.Cache) - out.Comments = CloneComments(n.Comments) - out.SelectExprs = CloneSelectExprs(n.SelectExprs) - out.From = CloneTableExprs(n.From) - out.Where = CloneRefOfWhere(n.Where) - out.GroupBy = CloneGroupBy(n.GroupBy) - out.Having = CloneRefOfWhere(n.Having) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - out.Into = CloneRefOfSelectInto(n.Into) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfSelect will visit all parts of the AST -func VisitRefOfSelect(in *Select, f Visit) error { +func VisitRefOfLockOption(in *LockOption, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { - return err + return nil +} +func VisitRefOfLockTables(in *LockTables, f Visit) error { + if in == nil { + return nil } - if err := VisitSelectExprs(in.SelectExprs, f); err != nil { + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableExprs(in.From, f); err != nil { + return nil +} +func VisitRefOfMatchExpr(in *MatchExpr, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfWhere(in.Where, f); err != nil { + if err := VisitSelectExprs(in.Columns, f); err != nil { return err } - if err := VisitGroupBy(in.GroupBy, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } - if err := VisitRefOfWhere(in.Having, f); err != nil { + return nil +} +func VisitRefOfModifyColumn(in *ModifyColumn, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { + if err := VisitRefOfColumnDefinition(in.NewColDefinition, f); err != nil { return err } - if err := VisitRefOfLimit(in.Limit, f); err != nil { + if err := VisitRefOfColName(in.First, f); err != nil { return err } - if err := VisitRefOfSelectInto(in.Into, f); err != nil { + if err := VisitRefOfColName(in.After, f); err != nil { return err } return nil } - -// EqualsSelectExprs does deep equals between the two objects. -func EqualsSelectExprs(a, b SelectExprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsSelectExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneSelectExprs creates a deep clone of the input. -func CloneSelectExprs(n SelectExprs) SelectExprs { - res := make(SelectExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneSelectExpr(x)) - } - return res -} - -// VisitSelectExprs will visit all parts of the AST -func VisitSelectExprs(in SelectExprs, f Visit) error { +func VisitRefOfNextval(in *Nextval, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitSelectExpr(el, f); err != nil { - return err - } + if err := VisitExpr(in.Expr, f); err != nil { + return err } return nil } - -// EqualsRefOfSelectInto does deep equals between the two objects. -func EqualsRefOfSelectInto(a, b *SelectInto) bool { - if a == b { - return true +func VisitRefOfNotExpr(in *NotExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.FileName == b.FileName && - a.Charset == b.Charset && - a.FormatOption == b.FormatOption && - a.ExportOption == b.ExportOption && - a.Manifest == b.Manifest && - a.Overwrite == b.Overwrite && - a.Type == b.Type -} - -// CloneRefOfSelectInto creates a deep clone of the input. -func CloneRefOfSelectInto(n *SelectInto) *SelectInto { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - return &out + return nil } - -// VisitRefOfSelectInto will visit all parts of the AST -func VisitRefOfSelectInto(in *SelectInto, f Visit) error { +func VisitRefOfNullVal(in *NullVal, f Visit) error { if in == nil { return nil } @@ -5389,714 +7915,490 @@ func VisitRefOfSelectInto(in *SelectInto, f Visit) error { } return nil } - -// EqualsRefOfSet does deep equals between the two objects. -func EqualsRefOfSet(a, b *Set) bool { - if a == b { - return true +func VisitRefOfOptLike(in *OptLike, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsComments(a.Comments, b.Comments) && - EqualsSetExprs(a.Exprs, b.Exprs) -} - -// CloneRefOfSet creates a deep clone of the input. -func CloneRefOfSet(n *Set) *Set { - if n == nil { - return nil + if err := VisitTableName(in.LikeTable, f); err != nil { + return err } - out := *n - out.Comments = CloneComments(n.Comments) - out.Exprs = CloneSetExprs(n.Exprs) - return &out -} - -// VisitRefOfSet will visit all parts of the AST -func VisitRefOfSet(in *Set, f Visit) error { + return nil +} +func VisitRefOfOrExpr(in *OrExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { + if err := VisitExpr(in.Left, f); err != nil { return err } - if err := VisitSetExprs(in.Exprs, f); err != nil { + if err := VisitExpr(in.Right, f); err != nil { return err } return nil } - -// EqualsRefOfSetExpr does deep equals between the two objects. -func EqualsRefOfSetExpr(a, b *SetExpr) bool { - if a == b { - return true +func VisitRefOfOrder(in *Order, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Scope == b.Scope && - EqualsColIdent(a.Name, b.Name) && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfSetExpr creates a deep clone of the input. -func CloneRefOfSetExpr(n *SetExpr) *SetExpr { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - out.Name = CloneColIdent(n.Name) - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfSetExpr will visit all parts of the AST -func VisitRefOfSetExpr(in *SetExpr, f Visit) error { +func VisitRefOfOrderByOption(in *OrderByOption, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { + if err := VisitColumns(in.Cols, f); err != nil { return err } return nil } - -// EqualsSetExprs does deep equals between the two objects. -func EqualsSetExprs(a, b SetExprs) bool { - if len(a) != len(b) { - return false +func VisitRefOfOtherAdmin(in *OtherAdmin, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfSetExpr(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true + return nil } - -// CloneSetExprs creates a deep clone of the input. -func CloneSetExprs(n SetExprs) SetExprs { - res := make(SetExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfSetExpr(x)) +func VisitRefOfOtherRead(in *OtherRead, f Visit) error { + if in == nil { + return nil } - return res + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitSetExprs will visit all parts of the AST -func VisitSetExprs(in SetExprs, f Visit) error { +func VisitRefOfParenSelect(in *ParenSelect, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitRefOfSetExpr(el, f); err != nil { - return err - } + if err := VisitSelectStatement(in.Select, f); err != nil { + return err } return nil } - -// EqualsRefOfSetTransaction does deep equals between the two objects. -func EqualsRefOfSetTransaction(a, b *SetTransaction) bool { - if a == b { - return true +func VisitRefOfParenTableExpr(in *ParenTableExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSQLNode(a.SQLNode, b.SQLNode) && - EqualsComments(a.Comments, b.Comments) && - a.Scope == b.Scope && - EqualsSliceOfCharacteristic(a.Characteristics, b.Characteristics) + if err := VisitTableExprs(in.Exprs, f); err != nil { + return err + } + return nil } - -// CloneRefOfSetTransaction creates a deep clone of the input. -func CloneRefOfSetTransaction(n *SetTransaction) *SetTransaction { - if n == nil { +func VisitRefOfPartitionDefinition(in *PartitionDefinition, f Visit) error { + if in == nil { return nil } - out := *n - out.SQLNode = CloneSQLNode(n.SQLNode) - out.Comments = CloneComments(n.Comments) - out.Characteristics = CloneSliceOfCharacteristic(n.Characteristics) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitColIdent(in.Name, f); err != nil { + return err + } + if err := VisitExpr(in.Limit, f); err != nil { + return err + } + return nil } - -// VisitRefOfSetTransaction will visit all parts of the AST -func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { +func VisitRefOfPartitionSpec(in *PartitionSpec, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSQLNode(in.SQLNode, f); err != nil { + if err := VisitPartitions(in.Names, f); err != nil { return err } - if err := VisitComments(in.Comments, f); err != nil { + if err := VisitRefOfLiteral(in.Number, f); err != nil { return err } - for _, el := range in.Characteristics { - if err := VisitCharacteristic(el, f); err != nil { + if err := VisitTableName(in.TableName, f); err != nil { + return err + } + for _, el := range in.Definitions { + if err := VisitRefOfPartitionDefinition(el, f); err != nil { return err } } return nil } - -// EqualsRefOfShow does deep equals between the two objects. -func EqualsRefOfShow(a, b *Show) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsShowInternal(a.Internal, b.Internal) -} - -// CloneRefOfShow creates a deep clone of the input. -func CloneRefOfShow(n *Show) *Show { - if n == nil { - return nil - } - out := *n - out.Internal = CloneShowInternal(n.Internal) - return &out -} - -// VisitRefOfShow will visit all parts of the AST -func VisitRefOfShow(in *Show, f Visit) error { +func VisitRefOfRangeCond(in *RangeCond, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitShowInternal(in.Internal, f); err != nil { + if err := VisitExpr(in.Left, f); err != nil { return err } - return nil -} - -// EqualsRefOfShowBasic does deep equals between the two objects. -func EqualsRefOfShowBasic(a, b *ShowBasic) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false + if err := VisitExpr(in.From, f); err != nil { + return err } - return a.Full == b.Full && - a.DbName == b.DbName && - a.Command == b.Command && - EqualsTableName(a.Tbl, b.Tbl) && - EqualsRefOfShowFilter(a.Filter, b.Filter) -} - -// CloneRefOfShowBasic creates a deep clone of the input. -func CloneRefOfShowBasic(n *ShowBasic) *ShowBasic { - if n == nil { - return nil + if err := VisitExpr(in.To, f); err != nil { + return err } - out := *n - out.Tbl = CloneTableName(n.Tbl) - out.Filter = CloneRefOfShowFilter(n.Filter) - return &out + return nil } - -// VisitRefOfShowBasic will visit all parts of the AST -func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { +func VisitRefOfRelease(in *Release, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.Tbl, f); err != nil { - return err - } - if err := VisitRefOfShowFilter(in.Filter, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } return nil } - -// EqualsRefOfShowCreate does deep equals between the two objects. -func EqualsRefOfShowCreate(a, b *ShowCreate) bool { - if a == b { - return true +func VisitRefOfRenameIndex(in *RenameIndex, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Command == b.Command && - EqualsTableName(a.Op, b.Op) + return nil } - -// CloneRefOfShowCreate creates a deep clone of the input. -func CloneRefOfShowCreate(n *ShowCreate) *ShowCreate { - if n == nil { +func VisitRefOfRenameTable(in *RenameTable, f Visit) error { + if in == nil { return nil } - out := *n - out.Op = CloneTableName(n.Op) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfShowCreate will visit all parts of the AST -func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { +func VisitRefOfRenameTableName(in *RenameTableName, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.Op, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { return err } return nil } - -// EqualsRefOfShowFilter does deep equals between the two objects. -func EqualsRefOfShowFilter(a, b *ShowFilter) bool { - if a == b { - return true +func VisitRefOfRevertMigration(in *RevertMigration, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Like == b.Like && - EqualsExpr(a.Filter, b.Filter) + return nil } - -// CloneRefOfShowFilter creates a deep clone of the input. -func CloneRefOfShowFilter(n *ShowFilter) *ShowFilter { - if n == nil { +func VisitRefOfRollback(in *Rollback, f Visit) error { + if in == nil { return nil } - out := *n - out.Filter = CloneExpr(n.Filter) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfShowFilter will visit all parts of the AST -func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { +func VisitRefOfSRollback(in *SRollback, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Filter, f); err != nil { + if err := VisitColIdent(in.Name, f); err != nil { return err } return nil } - -// EqualsRefOfShowLegacy does deep equals between the two objects. -func EqualsRefOfShowLegacy(a, b *ShowLegacy) bool { - if a == b { - return true +func VisitRefOfSavepoint(in *Savepoint, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Extended == b.Extended && - a.Type == b.Type && - EqualsTableName(a.OnTable, b.OnTable) && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfShowTablesOpt(a.ShowTablesOpt, b.ShowTablesOpt) && - a.Scope == b.Scope && - EqualsExpr(a.ShowCollationFilterOpt, b.ShowCollationFilterOpt) -} - -// CloneRefOfShowLegacy creates a deep clone of the input. -func CloneRefOfShowLegacy(n *ShowLegacy) *ShowLegacy { - if n == nil { - return nil + if err := VisitColIdent(in.Name, f); err != nil { + return err } - out := *n - out.OnTable = CloneTableName(n.OnTable) - out.Table = CloneTableName(n.Table) - out.ShowTablesOpt = CloneRefOfShowTablesOpt(n.ShowTablesOpt) - out.ShowCollationFilterOpt = CloneExpr(n.ShowCollationFilterOpt) - return &out + return nil } - -// VisitRefOfShowLegacy will visit all parts of the AST -func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { +func VisitRefOfSelect(in *Select, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.OnTable, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - if err := VisitTableName(in.Table, f); err != nil { + if err := VisitSelectExprs(in.SelectExprs, f); err != nil { return err } - if err := VisitExpr(in.ShowCollationFilterOpt, f); err != nil { + if err := VisitTableExprs(in.From, f); err != nil { return err } - return nil -} - -// EqualsRefOfStarExpr does deep equals between the two objects. -func EqualsRefOfStarExpr(a, b *StarExpr) bool { - if a == b { - return true + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitGroupBy(in.GroupBy, f); err != nil { + return err } - return EqualsTableName(a.TableName, b.TableName) + if err := VisitRefOfWhere(in.Having, f); err != nil { + return err + } + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err + } + if err := VisitRefOfSelectInto(in.Into, f); err != nil { + return err + } + return nil } - -// CloneRefOfStarExpr creates a deep clone of the input. -func CloneRefOfStarExpr(n *StarExpr) *StarExpr { - if n == nil { +func VisitRefOfSelectInto(in *SelectInto, f Visit) error { + if in == nil { return nil } - out := *n - out.TableName = CloneTableName(n.TableName) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfStarExpr will visit all parts of the AST -func VisitRefOfStarExpr(in *StarExpr, f Visit) error { +func VisitRefOfSet(in *Set, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.TableName, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitSetExprs(in.Exprs, f); err != nil { return err } return nil } - -// EqualsRefOfStream does deep equals between the two objects. -func EqualsRefOfStream(a, b *Stream) bool { - if a == b { - return true +func VisitRefOfSetExpr(in *SetExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsComments(a.Comments, b.Comments) && - EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && - EqualsTableName(a.Table, b.Table) -} - -// CloneRefOfStream creates a deep clone of the input. -func CloneRefOfStream(n *Stream) *Stream { - if n == nil { - return nil + if err := VisitColIdent(in.Name, f); err != nil { + return err } - out := *n - out.Comments = CloneComments(n.Comments) - out.SelectExpr = CloneSelectExpr(n.SelectExpr) - out.Table = CloneTableName(n.Table) - return &out + if err := VisitExpr(in.Expr, f); err != nil { + return err + } + return nil } - -// VisitRefOfStream will visit all parts of the AST -func VisitRefOfStream(in *Stream, f Visit) error { +func VisitRefOfSetTransaction(in *SetTransaction, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { + if err := VisitSQLNode(in.SQLNode, f); err != nil { return err } - if err := VisitSelectExpr(in.SelectExpr, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - if err := VisitTableName(in.Table, f); err != nil { - return err + for _, el := range in.Characteristics { + if err := VisitCharacteristic(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfSubquery does deep equals between the two objects. -func EqualsRefOfSubquery(a, b *Subquery) bool { - if a == b { - return true +func VisitRefOfShow(in *Show, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSelectStatement(a.Select, b.Select) -} - -// CloneRefOfSubquery creates a deep clone of the input. -func CloneRefOfSubquery(n *Subquery) *Subquery { - if n == nil { - return nil + if err := VisitShowInternal(in.Internal, f); err != nil { + return err } - out := *n - out.Select = CloneSelectStatement(n.Select) - return &out + return nil } - -// VisitRefOfSubquery will visit all parts of the AST -func VisitRefOfSubquery(in *Subquery, f Visit) error { +func VisitRefOfShowBasic(in *ShowBasic, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectStatement(in.Select, f); err != nil { + if err := VisitTableName(in.Tbl, f); err != nil { + return err + } + if err := VisitRefOfShowFilter(in.Filter, f); err != nil { return err } return nil } - -// EqualsRefOfSubstrExpr does deep equals between the two objects. -func EqualsRefOfSubstrExpr(a, b *SubstrExpr) bool { - if a == b { - return true +func VisitRefOfShowCreate(in *ShowCreate, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsRefOfColName(a.Name, b.Name) && - EqualsRefOfLiteral(a.StrVal, b.StrVal) && - EqualsExpr(a.From, b.From) && - EqualsExpr(a.To, b.To) -} - -// CloneRefOfSubstrExpr creates a deep clone of the input. -func CloneRefOfSubstrExpr(n *SubstrExpr) *SubstrExpr { - if n == nil { - return nil + if err := VisitTableName(in.Op, f); err != nil { + return err } - out := *n - out.Name = CloneRefOfColName(n.Name) - out.StrVal = CloneRefOfLiteral(n.StrVal) - out.From = CloneExpr(n.From) - out.To = CloneExpr(n.To) - return &out + return nil } - -// VisitRefOfSubstrExpr will visit all parts of the AST -func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { +func VisitRefOfShowFilter(in *ShowFilter, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColName(in.Name, f); err != nil { + if err := VisitExpr(in.Filter, f); err != nil { return err } - if err := VisitRefOfLiteral(in.StrVal, f); err != nil { + return nil +} +func VisitRefOfShowLegacy(in *ShowLegacy, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.From, f); err != nil { + if err := VisitTableName(in.OnTable, f); err != nil { return err } - if err := VisitExpr(in.To, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { + return err + } + if err := VisitExpr(in.ShowCollationFilterOpt, f); err != nil { return err } return nil } - -// EqualsTableExprs does deep equals between the two objects. -func EqualsTableExprs(a, b TableExprs) bool { - if len(a) != len(b) { - return false +func VisitRefOfStarExpr(in *StarExpr, f Visit) error { + if in == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsTableExpr(a[i], b[i]) { - return false - } + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneTableExprs creates a deep clone of the input. -func CloneTableExprs(n TableExprs) TableExprs { - res := make(TableExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneTableExpr(x)) + if err := VisitTableName(in.TableName, f); err != nil { + return err } - return res + return nil } - -// VisitTableExprs will visit all parts of the AST -func VisitTableExprs(in TableExprs, f Visit) error { +func VisitRefOfStream(in *Stream, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitTableExpr(el, f); err != nil { - return err - } + if err := VisitComments(in.Comments, f); err != nil { + return err + } + if err := VisitSelectExpr(in.SelectExpr, f); err != nil { + return err + } + if err := VisitTableName(in.Table, f); err != nil { + return err } return nil } - -// EqualsTableIdent does deep equals between the two objects. -func EqualsTableIdent(a, b TableIdent) bool { - return a.v == b.v -} - -// CloneTableIdent creates a deep clone of the input. -func CloneTableIdent(n TableIdent) TableIdent { - return *CloneRefOfTableIdent(&n) -} - -// VisitTableIdent will visit all parts of the AST -func VisitTableIdent(in TableIdent, f Visit) error { +func VisitRefOfSubquery(in *Subquery, f Visit) error { + if in == nil { + return nil + } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitSelectStatement(in.Select, f); err != nil { + return err + } return nil } - -// EqualsTableName does deep equals between the two objects. -func EqualsTableName(a, b TableName) bool { - return EqualsTableIdent(a.Name, b.Name) && - EqualsTableIdent(a.Qualifier, b.Qualifier) -} - -// CloneTableName creates a deep clone of the input. -func CloneTableName(n TableName) TableName { - return *CloneRefOfTableName(&n) -} - -// VisitTableName will visit all parts of the AST -func VisitTableName(in TableName, f Visit) error { +func VisitRefOfSubstrExpr(in *SubstrExpr, f Visit) error { + if in == nil { + return nil + } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableIdent(in.Name, f); err != nil { + if err := VisitRefOfColName(in.Name, f); err != nil { return err } - if err := VisitTableIdent(in.Qualifier, f); err != nil { + if err := VisitRefOfLiteral(in.StrVal, f); err != nil { return err } - return nil -} - -// EqualsTableNames does deep equals between the two objects. -func EqualsTableNames(a, b TableNames) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsTableName(a[i], b[i]) { - return false - } + if err := VisitExpr(in.From, f); err != nil { + return err } - return true -} - -// CloneTableNames creates a deep clone of the input. -func CloneTableNames(n TableNames) TableNames { - res := make(TableNames, 0, len(n)) - for _, x := range n { - res = append(res, CloneTableName(x)) + if err := VisitExpr(in.To, f); err != nil { + return err } - return res + return nil } - -// VisitTableNames will visit all parts of the AST -func VisitTableNames(in TableNames, f Visit) error { +func VisitRefOfTableIdent(in *TableIdent, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in { - if err := VisitTableName(el, f); err != nil { - return err - } - } return nil } - -// EqualsTableOptions does deep equals between the two objects. -func EqualsTableOptions(a, b TableOptions) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfTableOption(a[i], b[i]) { - return false - } - } - return true -} - -// CloneTableOptions creates a deep clone of the input. -func CloneTableOptions(n TableOptions) TableOptions { - res := make(TableOptions, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfTableOption(x)) +func VisitRefOfTableName(in *TableName, f Visit) error { + if in == nil { + return nil } - return res -} - -// VisitTableOptions will visit all parts of the AST -func VisitTableOptions(in TableOptions, f Visit) error { - _, err := f(in) - return err -} - -// EqualsRefOfTableSpec does deep equals between the two objects. -func EqualsRefOfTableSpec(a, b *TableSpec) bool { - if a == b { - return true + if cont, err := f(in); err != nil || !cont { + return err } - if a == nil || b == nil { - return false + if err := VisitTableIdent(in.Name, f); err != nil { + return err } - return EqualsSliceOfRefOfColumnDefinition(a.Columns, b.Columns) && - EqualsSliceOfRefOfIndexDefinition(a.Indexes, b.Indexes) && - EqualsSliceOfRefOfConstraintDefinition(a.Constraints, b.Constraints) && - EqualsTableOptions(a.Options, b.Options) -} - -// CloneRefOfTableSpec creates a deep clone of the input. -func CloneRefOfTableSpec(n *TableSpec) *TableSpec { - if n == nil { - return nil + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err } - out := *n - out.Columns = CloneSliceOfRefOfColumnDefinition(n.Columns) - out.Indexes = CloneSliceOfRefOfIndexDefinition(n.Indexes) - out.Constraints = CloneSliceOfRefOfConstraintDefinition(n.Constraints) - out.Options = CloneTableOptions(n.Options) - return &out + return nil } - -// VisitRefOfTableSpec will visit all parts of the AST func VisitRefOfTableSpec(in *TableSpec, f Visit) error { if in == nil { return nil @@ -6124,196 +8426,169 @@ func VisitRefOfTableSpec(in *TableSpec, f Visit) error { } return nil } - -// EqualsRefOfTablespaceOperation does deep equals between the two objects. -func EqualsRefOfTablespaceOperation(a, b *TablespaceOperation) bool { - if a == b { - return true +func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Import == b.Import + return nil } - -// CloneRefOfTablespaceOperation creates a deep clone of the input. -func CloneRefOfTablespaceOperation(n *TablespaceOperation) *TablespaceOperation { - if n == nil { +func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { + if in == nil { return nil } - out := *n - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + if err := VisitExpr(in.Expr1, f); err != nil { + return err + } + if err := VisitExpr(in.Expr2, f); err != nil { + return err + } + return nil } - -// VisitRefOfTablespaceOperation will visit all parts of the AST -func VisitRefOfTablespaceOperation(in *TablespaceOperation, f Visit) error { +func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitTableName(in.Table, f); err != nil { + return err + } return nil } - -// EqualsRefOfTimestampFuncExpr does deep equals between the two objects. -func EqualsRefOfTimestampFuncExpr(a, b *TimestampFuncExpr) bool { - if a == b { - return true +func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return a.Name == b.Name && - a.Unit == b.Unit && - EqualsExpr(a.Expr1, b.Expr1) && - EqualsExpr(a.Expr2, b.Expr2) -} - -// CloneRefOfTimestampFuncExpr creates a deep clone of the input. -func CloneRefOfTimestampFuncExpr(n *TimestampFuncExpr) *TimestampFuncExpr { - if n == nil { - return nil + if err := VisitExpr(in.Expr, f); err != nil { + return err } - out := *n - out.Expr1 = CloneExpr(n.Expr1) - out.Expr2 = CloneExpr(n.Expr2) - return &out + return nil } - -// VisitRefOfTimestampFuncExpr will visit all parts of the AST -func VisitRefOfTimestampFuncExpr(in *TimestampFuncExpr, f Visit) error { +func VisitRefOfUnion(in *Union, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr1, f); err != nil { + if err := VisitSelectStatement(in.FirstStatement, f); err != nil { return err } - if err := VisitExpr(in.Expr2, f); err != nil { + for _, el := range in.UnionSelects { + if err := VisitRefOfUnionSelect(el, f); err != nil { + return err + } + } + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { return err } return nil } - -// EqualsRefOfTruncateTable does deep equals between the two objects. -func EqualsRefOfTruncateTable(a, b *TruncateTable) bool { - if a == b { - return true +func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsTableName(a.Table, b.Table) + if err := VisitSelectStatement(in.Statement, f); err != nil { + return err + } + return nil } - -// CloneRefOfTruncateTable creates a deep clone of the input. -func CloneRefOfTruncateTable(n *TruncateTable) *TruncateTable { - if n == nil { +func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { + if in == nil { return nil } - out := *n - out.Table = CloneTableName(n.Table) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfTruncateTable will visit all parts of the AST -func VisitRefOfTruncateTable(in *TruncateTable, f Visit) error { +func VisitRefOfUpdate(in *Update, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableName(in.Table, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - return nil -} - -// EqualsRefOfUnaryExpr does deep equals between the two objects. -func EqualsRefOfUnaryExpr(a, b *UnaryExpr) bool { - if a == b { - return true + if err := VisitTableExprs(in.TableExprs, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitUpdateExprs(in.Exprs, f); err != nil { + return err } - return a.Operator == b.Operator && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfUnaryExpr creates a deep clone of the input. -func CloneRefOfUnaryExpr(n *UnaryExpr) *UnaryExpr { - if n == nil { - return nil + if err := VisitRefOfWhere(in.Where, f); err != nil { + return err } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + if err := VisitOrderBy(in.OrderBy, f); err != nil { + return err + } + if err := VisitRefOfLimit(in.Limit, f); err != nil { + return err + } + return nil } - -// VisitRefOfUnaryExpr will visit all parts of the AST -func VisitRefOfUnaryExpr(in *UnaryExpr, f Visit) error { +func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } + if err := VisitRefOfColName(in.Name, f); err != nil { + return err + } if err := VisitExpr(in.Expr, f); err != nil { return err } return nil } - -// EqualsRefOfUnion does deep equals between the two objects. -func EqualsRefOfUnion(a, b *Union) bool { - if a == b { - return true +func VisitRefOfUse(in *Use, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return EqualsSelectStatement(a.FirstStatement, b.FirstStatement) && - EqualsSliceOfRefOfUnionSelect(a.UnionSelects, b.UnionSelects) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) && - a.Lock == b.Lock -} - -// CloneRefOfUnion creates a deep clone of the input. -func CloneRefOfUnion(n *Union) *Union { - if n == nil { - return nil + if err := VisitTableIdent(in.DBName, f); err != nil { + return err } - out := *n - out.FirstStatement = CloneSelectStatement(n.FirstStatement) - out.UnionSelects = CloneSliceOfRefOfUnionSelect(n.UnionSelects) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out + return nil } - -// VisitRefOfUnion will visit all parts of the AST -func VisitRefOfUnion(in *Union, f Visit) error { +func VisitRefOfVStream(in *VStream, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectStatement(in.FirstStatement, f); err != nil { + if err := VisitComments(in.Comments, f); err != nil { return err } - for _, el := range in.UnionSelects { - if err := VisitRefOfUnionSelect(el, f); err != nil { - return err - } + if err := VisitSelectExpr(in.SelectExpr, f); err != nil { + return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { + if err := VisitTableName(in.Table, f); err != nil { + return err + } + if err := VisitRefOfWhere(in.Where, f); err != nil { return err } if err := VisitRefOfLimit(in.Limit, f); err != nil { @@ -6321,199 +8596,424 @@ func VisitRefOfUnion(in *Union, f Visit) error { } return nil } - -// EqualsRefOfUnionSelect does deep equals between the two objects. -func EqualsRefOfUnionSelect(a, b *UnionSelect) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Distinct == b.Distinct && - EqualsSelectStatement(a.Statement, b.Statement) -} - -// CloneRefOfUnionSelect creates a deep clone of the input. -func CloneRefOfUnionSelect(n *UnionSelect) *UnionSelect { - if n == nil { +func VisitRefOfValidation(in *Validation, f Visit) error { + if in == nil { return nil } - out := *n - out.Statement = CloneSelectStatement(n.Statement) - return &out + if cont, err := f(in); err != nil || !cont { + return err + } + return nil } - -// VisitRefOfUnionSelect will visit all parts of the AST -func VisitRefOfUnionSelect(in *UnionSelect, f Visit) error { +func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitSelectStatement(in.Statement, f); err != nil { + if err := VisitRefOfColName(in.Name, f); err != nil { return err } return nil } - -// EqualsRefOfUnlockTables does deep equals between the two objects. -func EqualsRefOfUnlockTables(a, b *UnlockTables) bool { - if a == b { - return true +func VisitRefOfVindexParam(in *VindexParam, f Visit) error { + if in == nil { + return nil } - if a == nil || b == nil { - return false + if cont, err := f(in); err != nil || !cont { + return err } - return true -} - -// CloneRefOfUnlockTables creates a deep clone of the input. -func CloneRefOfUnlockTables(n *UnlockTables) *UnlockTables { - if n == nil { - return nil + if err := VisitColIdent(in.Key, f); err != nil { + return err } - out := *n - return &out + return nil } - -// VisitRefOfUnlockTables will visit all parts of the AST -func VisitRefOfUnlockTables(in *UnlockTables, f Visit) error { +func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - return nil -} - -// EqualsRefOfUpdate does deep equals between the two objects. -func EqualsRefOfUpdate(a, b *Update) bool { - if a == b { - return true + if err := VisitColIdent(in.Name, f); err != nil { + return err } - if a == nil || b == nil { - return false + if err := VisitColIdent(in.Type, f); err != nil { + return err } - return EqualsComments(a.Comments, b.Comments) && - a.Ignore == b.Ignore && - EqualsTableExprs(a.TableExprs, b.TableExprs) && - EqualsUpdateExprs(a.Exprs, b.Exprs) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsOrderBy(a.OrderBy, b.OrderBy) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfUpdate creates a deep clone of the input. -func CloneRefOfUpdate(n *Update) *Update { - if n == nil { - return nil + for _, el := range in.Params { + if err := VisitVindexParam(el, f); err != nil { + return err + } } - out := *n - out.Comments = CloneComments(n.Comments) - out.TableExprs = CloneTableExprs(n.TableExprs) - out.Exprs = CloneUpdateExprs(n.Exprs) - out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) - out.Limit = CloneRefOfLimit(n.Limit) - return &out + return nil } - -// VisitRefOfUpdate will visit all parts of the AST -func VisitRefOfUpdate(in *Update, f Visit) error { +func VisitRefOfWhen(in *When, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitComments(in.Comments, f); err != nil { + if err := VisitExpr(in.Cond, f); err != nil { return err } - if err := VisitTableExprs(in.TableExprs, f); err != nil { + if err := VisitExpr(in.Val, f); err != nil { return err } - if err := VisitUpdateExprs(in.Exprs, f); err != nil { + return nil +} +func VisitRefOfWhere(in *Where, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfWhere(in.Where, f); err != nil { + if err := VisitExpr(in.Expr, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { + return nil +} +func VisitRefOfXorExpr(in *XorExpr, f Visit) error { + if in == nil { + return nil + } + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfLimit(in.Limit, f); err != nil { + if err := VisitExpr(in.Left, f); err != nil { + return err + } + if err := VisitExpr(in.Right, f); err != nil { return err } return nil } - -// EqualsRefOfUpdateExpr does deep equals between the two objects. -func EqualsRefOfUpdateExpr(a, b *UpdateExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfColName(a.Name, b.Name) && - EqualsExpr(a.Expr, b.Expr) +func VisitReferenceAction(in ReferenceAction, f Visit) error { + _, err := f(in) + return err } - -// CloneRefOfUpdateExpr creates a deep clone of the input. -func CloneRefOfUpdateExpr(n *UpdateExpr) *UpdateExpr { - if n == nil { +func VisitSQLNode(in SQLNode, f Visit) error { + if in == nil { + return nil + } + switch in := in.(type) { + case AccessMode: + return VisitAccessMode(in, f) + case *AddColumns: + return VisitRefOfAddColumns(in, f) + case *AddConstraintDefinition: + return VisitRefOfAddConstraintDefinition(in, f) + case *AddIndexDefinition: + return VisitRefOfAddIndexDefinition(in, f) + case AlgorithmValue: + return VisitAlgorithmValue(in, f) + case *AliasedExpr: + return VisitRefOfAliasedExpr(in, f) + case *AliasedTableExpr: + return VisitRefOfAliasedTableExpr(in, f) + case *AlterCharset: + return VisitRefOfAlterCharset(in, f) + case *AlterColumn: + return VisitRefOfAlterColumn(in, f) + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *AlterMigration: + return VisitRefOfAlterMigration(in, f) + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *AlterVschema: + return VisitRefOfAlterVschema(in, f) + case *AndExpr: + return VisitRefOfAndExpr(in, f) + case Argument: + return VisitArgument(in, f) + case *AutoIncSpec: + return VisitRefOfAutoIncSpec(in, f) + case *Begin: + return VisitRefOfBegin(in, f) + case *BinaryExpr: + return VisitRefOfBinaryExpr(in, f) + case BoolVal: + return VisitBoolVal(in, f) + case *CallProc: + return VisitRefOfCallProc(in, f) + case *CaseExpr: + return VisitRefOfCaseExpr(in, f) + case *ChangeColumn: + return VisitRefOfChangeColumn(in, f) + case *CheckConstraintDefinition: + return VisitRefOfCheckConstraintDefinition(in, f) + case ColIdent: + return VisitColIdent(in, f) + case *ColName: + return VisitRefOfColName(in, f) + case *CollateExpr: + return VisitRefOfCollateExpr(in, f) + case *ColumnDefinition: + return VisitRefOfColumnDefinition(in, f) + case *ColumnType: + return VisitRefOfColumnType(in, f) + case Columns: + return VisitColumns(in, f) + case Comments: + return VisitComments(in, f) + case *Commit: + return VisitRefOfCommit(in, f) + case *ComparisonExpr: + return VisitRefOfComparisonExpr(in, f) + case *ConstraintDefinition: + return VisitRefOfConstraintDefinition(in, f) + case *ConvertExpr: + return VisitRefOfConvertExpr(in, f) + case *ConvertType: + return VisitRefOfConvertType(in, f) + case *ConvertUsingExpr: + return VisitRefOfConvertUsingExpr(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *CurTimeFuncExpr: + return VisitRefOfCurTimeFuncExpr(in, f) + case *Default: + return VisitRefOfDefault(in, f) + case *Delete: + return VisitRefOfDelete(in, f) + case *DerivedTable: + return VisitRefOfDerivedTable(in, f) + case *DropColumn: + return VisitRefOfDropColumn(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + case *DropKey: + return VisitRefOfDropKey(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *ExistsExpr: + return VisitRefOfExistsExpr(in, f) + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + case Exprs: + return VisitExprs(in, f) + case *Flush: + return VisitRefOfFlush(in, f) + case *Force: + return VisitRefOfForce(in, f) + case *ForeignKeyDefinition: + return VisitRefOfForeignKeyDefinition(in, f) + case *FuncExpr: + return VisitRefOfFuncExpr(in, f) + case GroupBy: + return VisitGroupBy(in, f) + case *GroupConcatExpr: + return VisitRefOfGroupConcatExpr(in, f) + case *IndexDefinition: + return VisitRefOfIndexDefinition(in, f) + case *IndexHints: + return VisitRefOfIndexHints(in, f) + case *IndexInfo: + return VisitRefOfIndexInfo(in, f) + case *Insert: + return VisitRefOfInsert(in, f) + case *IntervalExpr: + return VisitRefOfIntervalExpr(in, f) + case *IsExpr: + return VisitRefOfIsExpr(in, f) + case IsolationLevel: + return VisitIsolationLevel(in, f) + case JoinCondition: + return VisitJoinCondition(in, f) + case *JoinTableExpr: + return VisitRefOfJoinTableExpr(in, f) + case *KeyState: + return VisitRefOfKeyState(in, f) + case *Limit: + return VisitRefOfLimit(in, f) + case ListArg: + return VisitListArg(in, f) + case *Literal: + return VisitRefOfLiteral(in, f) + case *Load: + return VisitRefOfLoad(in, f) + case *LockOption: + return VisitRefOfLockOption(in, f) + case *LockTables: + return VisitRefOfLockTables(in, f) + case *MatchExpr: + return VisitRefOfMatchExpr(in, f) + case *ModifyColumn: + return VisitRefOfModifyColumn(in, f) + case *Nextval: + return VisitRefOfNextval(in, f) + case *NotExpr: + return VisitRefOfNotExpr(in, f) + case *NullVal: + return VisitRefOfNullVal(in, f) + case OnDup: + return VisitOnDup(in, f) + case *OptLike: + return VisitRefOfOptLike(in, f) + case *OrExpr: + return VisitRefOfOrExpr(in, f) + case *Order: + return VisitRefOfOrder(in, f) + case OrderBy: + return VisitOrderBy(in, f) + case *OrderByOption: + return VisitRefOfOrderByOption(in, f) + case *OtherAdmin: + return VisitRefOfOtherAdmin(in, f) + case *OtherRead: + return VisitRefOfOtherRead(in, f) + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *ParenTableExpr: + return VisitRefOfParenTableExpr(in, f) + case *PartitionDefinition: + return VisitRefOfPartitionDefinition(in, f) + case *PartitionSpec: + return VisitRefOfPartitionSpec(in, f) + case Partitions: + return VisitPartitions(in, f) + case *RangeCond: + return VisitRefOfRangeCond(in, f) + case ReferenceAction: + return VisitReferenceAction(in, f) + case *Release: + return VisitRefOfRelease(in, f) + case *RenameIndex: + return VisitRefOfRenameIndex(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *RenameTableName: + return VisitRefOfRenameTableName(in, f) + case *RevertMigration: + return VisitRefOfRevertMigration(in, f) + case *Rollback: + return VisitRefOfRollback(in, f) + case *SRollback: + return VisitRefOfSRollback(in, f) + case *Savepoint: + return VisitRefOfSavepoint(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case SelectExprs: + return VisitSelectExprs(in, f) + case *SelectInto: + return VisitRefOfSelectInto(in, f) + case *Set: + return VisitRefOfSet(in, f) + case *SetExpr: + return VisitRefOfSetExpr(in, f) + case SetExprs: + return VisitSetExprs(in, f) + case *SetTransaction: + return VisitRefOfSetTransaction(in, f) + case *Show: + return VisitRefOfShow(in, f) + case *ShowBasic: + return VisitRefOfShowBasic(in, f) + case *ShowCreate: + return VisitRefOfShowCreate(in, f) + case *ShowFilter: + return VisitRefOfShowFilter(in, f) + case *ShowLegacy: + return VisitRefOfShowLegacy(in, f) + case *StarExpr: + return VisitRefOfStarExpr(in, f) + case *Stream: + return VisitRefOfStream(in, f) + case *Subquery: + return VisitRefOfSubquery(in, f) + case *SubstrExpr: + return VisitRefOfSubstrExpr(in, f) + case TableExprs: + return VisitTableExprs(in, f) + case TableIdent: + return VisitTableIdent(in, f) + case TableName: + return VisitTableName(in, f) + case TableNames: + return VisitTableNames(in, f) + case TableOptions: + return VisitTableOptions(in, f) + case *TableSpec: + return VisitRefOfTableSpec(in, f) + case *TablespaceOperation: + return VisitRefOfTablespaceOperation(in, f) + case *TimestampFuncExpr: + return VisitRefOfTimestampFuncExpr(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + case *UnaryExpr: + return VisitRefOfUnaryExpr(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case *UnionSelect: + return VisitRefOfUnionSelect(in, f) + case *UnlockTables: + return VisitRefOfUnlockTables(in, f) + case *Update: + return VisitRefOfUpdate(in, f) + case *UpdateExpr: + return VisitRefOfUpdateExpr(in, f) + case UpdateExprs: + return VisitUpdateExprs(in, f) + case *Use: + return VisitRefOfUse(in, f) + case *VStream: + return VisitRefOfVStream(in, f) + case ValTuple: + return VisitValTuple(in, f) + case *Validation: + return VisitRefOfValidation(in, f) + case Values: + return VisitValues(in, f) + case *ValuesFuncExpr: + return VisitRefOfValuesFuncExpr(in, f) + case VindexParam: + return VisitVindexParam(in, f) + case *VindexSpec: + return VisitRefOfVindexSpec(in, f) + case *When: + return VisitRefOfWhen(in, f) + case *Where: + return VisitRefOfWhere(in, f) + case *XorExpr: + return VisitRefOfXorExpr(in, f) + default: + // this should never happen return nil } - out := *n - out.Name = CloneRefOfColName(n.Name) - out.Expr = CloneExpr(n.Expr) - return &out } - -// VisitRefOfUpdateExpr will visit all parts of the AST -func VisitRefOfUpdateExpr(in *UpdateExpr, f Visit) error { +func VisitSelectExpr(in SelectExpr, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err - } - if err := VisitExpr(in.Expr, f); err != nil { - return err - } - return nil -} - -// EqualsUpdateExprs does deep equals between the two objects. -func EqualsUpdateExprs(a, b UpdateExprs) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUpdateExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneUpdateExprs creates a deep clone of the input. -func CloneUpdateExprs(n UpdateExprs) UpdateExprs { - res := make(UpdateExprs, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUpdateExpr(x)) + switch in := in.(type) { + case *AliasedExpr: + return VisitRefOfAliasedExpr(in, f) + case *Nextval: + return VisitRefOfNextval(in, f) + case *StarExpr: + return VisitRefOfStarExpr(in, f) + default: + // this should never happen + return nil } - return res } - -// VisitUpdateExprs will visit all parts of the AST -func VisitUpdateExprs(in UpdateExprs, f Visit) error { +func VisitSelectExprs(in SelectExprs, f Visit) error { if in == nil { return nil } @@ -6521,196 +9021,181 @@ func VisitUpdateExprs(in UpdateExprs, f Visit) error { return err } for _, el := range in { - if err := VisitRefOfUpdateExpr(el, f); err != nil { + if err := VisitSelectExpr(el, f); err != nil { return err } } return nil } - -// EqualsRefOfUse does deep equals between the two objects. -func EqualsRefOfUse(a, b *Use) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitSelectStatement(in SelectStatement, f Visit) error { + if in == nil { + return nil } - return EqualsTableIdent(a.DBName, b.DBName) -} - -// CloneRefOfUse creates a deep clone of the input. -func CloneRefOfUse(n *Use) *Use { - if n == nil { + switch in := in.(type) { + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Union: + return VisitRefOfUnion(in, f) + default: + // this should never happen return nil } - out := *n - out.DBName = CloneTableIdent(n.DBName) - return &out } - -// VisitRefOfUse will visit all parts of the AST -func VisitRefOfUse(in *Use, f Visit) error { +func VisitSetExprs(in SetExprs, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitTableIdent(in.DBName, f); err != nil { - return err + for _, el := range in { + if err := VisitRefOfSetExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfVStream does deep equals between the two objects. -func EqualsRefOfVStream(a, b *VStream) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func VisitShowInternal(in ShowInternal, f Visit) error { + if in == nil { + return nil } - return EqualsComments(a.Comments, b.Comments) && - EqualsSelectExpr(a.SelectExpr, b.SelectExpr) && - EqualsTableName(a.Table, b.Table) && - EqualsRefOfWhere(a.Where, b.Where) && - EqualsRefOfLimit(a.Limit, b.Limit) -} - -// CloneRefOfVStream creates a deep clone of the input. -func CloneRefOfVStream(n *VStream) *VStream { - if n == nil { + switch in := in.(type) { + case *ShowBasic: + return VisitRefOfShowBasic(in, f) + case *ShowCreate: + return VisitRefOfShowCreate(in, f) + case *ShowLegacy: + return VisitRefOfShowLegacy(in, f) + default: + // this should never happen return nil } - out := *n - out.Comments = CloneComments(n.Comments) - out.SelectExpr = CloneSelectExpr(n.SelectExpr) - out.Table = CloneTableName(n.Table) - out.Where = CloneRefOfWhere(n.Where) - out.Limit = CloneRefOfLimit(n.Limit) - return &out } - -// VisitRefOfVStream will visit all parts of the AST -func VisitRefOfVStream(in *VStream, f Visit) error { +func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitComments(in.Comments, f); err != nil { - return err - } - if err := VisitSelectExpr(in.SelectExpr, f); err != nil { - return err - } - if err := VisitTableName(in.Table, f); err != nil { - return err - } - if err := VisitRefOfWhere(in.Where, f); err != nil { - return err - } - if err := VisitRefOfLimit(in.Limit, f); err != nil { - return err - } - return nil -} - -// EqualsValTuple does deep equals between the two objects. -func EqualsValTuple(a, b ValTuple) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsExpr(a[i], b[i]) { - return false - } - } - return true -} - -// CloneValTuple creates a deep clone of the input. -func CloneValTuple(n ValTuple) ValTuple { - res := make(ValTuple, 0, len(n)) - for _, x := range n { - res = append(res, CloneExpr(x)) + switch in := in.(type) { + case *DerivedTable: + return VisitRefOfDerivedTable(in, f) + case TableName: + return VisitTableName(in, f) + default: + // this should never happen + return nil } - return res } - -// VisitValTuple will visit all parts of the AST -func VisitValTuple(in ValTuple, f Visit) error { +func VisitStatement(in Statement, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - for _, el := range in { - if err := VisitExpr(el, f); err != nil { - return err - } - } - return nil -} - -// EqualsRefOfValidation does deep equals between the two objects. -func EqualsRefOfValidation(a, b *Validation) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.With == b.With -} - -// CloneRefOfValidation creates a deep clone of the input. -func CloneRefOfValidation(n *Validation) *Validation { - if n == nil { + switch in := in.(type) { + case *AlterDatabase: + return VisitRefOfAlterDatabase(in, f) + case *AlterMigration: + return VisitRefOfAlterMigration(in, f) + case *AlterTable: + return VisitRefOfAlterTable(in, f) + case *AlterView: + return VisitRefOfAlterView(in, f) + case *AlterVschema: + return VisitRefOfAlterVschema(in, f) + case *Begin: + return VisitRefOfBegin(in, f) + case *CallProc: + return VisitRefOfCallProc(in, f) + case *Commit: + return VisitRefOfCommit(in, f) + case *CreateDatabase: + return VisitRefOfCreateDatabase(in, f) + case *CreateTable: + return VisitRefOfCreateTable(in, f) + case *CreateView: + return VisitRefOfCreateView(in, f) + case *Delete: + return VisitRefOfDelete(in, f) + case *DropDatabase: + return VisitRefOfDropDatabase(in, f) + case *DropTable: + return VisitRefOfDropTable(in, f) + case *DropView: + return VisitRefOfDropView(in, f) + case *ExplainStmt: + return VisitRefOfExplainStmt(in, f) + case *ExplainTab: + return VisitRefOfExplainTab(in, f) + case *Flush: + return VisitRefOfFlush(in, f) + case *Insert: + return VisitRefOfInsert(in, f) + case *Load: + return VisitRefOfLoad(in, f) + case *LockTables: + return VisitRefOfLockTables(in, f) + case *OtherAdmin: + return VisitRefOfOtherAdmin(in, f) + case *OtherRead: + return VisitRefOfOtherRead(in, f) + case *ParenSelect: + return VisitRefOfParenSelect(in, f) + case *Release: + return VisitRefOfRelease(in, f) + case *RenameTable: + return VisitRefOfRenameTable(in, f) + case *RevertMigration: + return VisitRefOfRevertMigration(in, f) + case *Rollback: + return VisitRefOfRollback(in, f) + case *SRollback: + return VisitRefOfSRollback(in, f) + case *Savepoint: + return VisitRefOfSavepoint(in, f) + case *Select: + return VisitRefOfSelect(in, f) + case *Set: + return VisitRefOfSet(in, f) + case *SetTransaction: + return VisitRefOfSetTransaction(in, f) + case *Show: + return VisitRefOfShow(in, f) + case *Stream: + return VisitRefOfStream(in, f) + case *TruncateTable: + return VisitRefOfTruncateTable(in, f) + case *Union: + return VisitRefOfUnion(in, f) + case *UnlockTables: + return VisitRefOfUnlockTables(in, f) + case *Update: + return VisitRefOfUpdate(in, f) + case *Use: + return VisitRefOfUse(in, f) + case *VStream: + return VisitRefOfVStream(in, f) + default: + // this should never happen return nil } - out := *n - return &out } - -// VisitRefOfValidation will visit all parts of the AST -func VisitRefOfValidation(in *Validation, f Visit) error { +func VisitTableExpr(in TableExpr, f Visit) error { if in == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// EqualsValues does deep equals between the two objects. -func EqualsValues(a, b Values) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsValTuple(a[i], b[i]) { - return false - } - } - return true -} - -// CloneValues creates a deep clone of the input. -func CloneValues(n Values) Values { - res := make(Values, 0, len(n)) - for _, x := range n { - res = append(res, CloneValTuple(x)) + switch in := in.(type) { + case *AliasedTableExpr: + return VisitRefOfAliasedTableExpr(in, f) + case *JoinTableExpr: + return VisitRefOfJoinTableExpr(in, f) + case *ParenTableExpr: + return VisitRefOfParenTableExpr(in, f) + default: + // this should never happen + return nil } - return res } - -// VisitValues will visit all parts of the AST -func VisitValues(in Values, f Visit) error { +func VisitTableExprs(in TableExprs, f Visit) error { if in == nil { return nil } @@ -6718,2910 +9203,5369 @@ func VisitValues(in Values, f Visit) error { return err } for _, el := range in { - if err := VisitValTuple(el, f); err != nil { + if err := VisitTableExpr(el, f); err != nil { return err } } return nil } - -// EqualsRefOfValuesFuncExpr does deep equals between the two objects. -func EqualsRefOfValuesFuncExpr(a, b *ValuesFuncExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsRefOfColName(a.Name, b.Name) -} - -// CloneRefOfValuesFuncExpr creates a deep clone of the input. -func CloneRefOfValuesFuncExpr(n *ValuesFuncExpr) *ValuesFuncExpr { - if n == nil { - return nil - } - out := *n - out.Name = CloneRefOfColName(n.Name) - return &out -} - -// VisitRefOfValuesFuncExpr will visit all parts of the AST -func VisitRefOfValuesFuncExpr(in *ValuesFuncExpr, f Visit) error { - if in == nil { - return nil - } +func VisitTableIdent(in TableIdent, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - if err := VisitRefOfColName(in.Name, f); err != nil { - return err - } return nil } - -// EqualsVindexParam does deep equals between the two objects. -func EqualsVindexParam(a, b VindexParam) bool { - return a.Val == b.Val && - EqualsColIdent(a.Key, b.Key) -} - -// CloneVindexParam creates a deep clone of the input. -func CloneVindexParam(n VindexParam) VindexParam { - return *CloneRefOfVindexParam(&n) -} - -// VisitVindexParam will visit all parts of the AST -func VisitVindexParam(in VindexParam, f Visit) error { +func VisitTableName(in TableName, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Key, f); err != nil { + if err := VisitTableIdent(in.Name, f); err != nil { return err } - return nil -} - -// EqualsRefOfVindexSpec does deep equals between the two objects. -func EqualsRefOfVindexSpec(a, b *VindexSpec) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsColIdent(a.Name, b.Name) && - EqualsColIdent(a.Type, b.Type) && - EqualsSliceOfVindexParam(a.Params, b.Params) -} - -// CloneRefOfVindexSpec creates a deep clone of the input. -func CloneRefOfVindexSpec(n *VindexSpec) *VindexSpec { - if n == nil { - return nil + if err := VisitTableIdent(in.Qualifier, f); err != nil { + return err } - out := *n - out.Name = CloneColIdent(n.Name) - out.Type = CloneColIdent(n.Type) - out.Params = CloneSliceOfVindexParam(n.Params) - return &out + return nil } - -// VisitRefOfVindexSpec will visit all parts of the AST -func VisitRefOfVindexSpec(in *VindexSpec, f Visit) error { +func VisitTableNames(in TableNames, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitColIdent(in.Name, f); err != nil { - return err - } - if err := VisitColIdent(in.Type, f); err != nil { - return err - } - for _, el := range in.Params { - if err := VisitVindexParam(el, f); err != nil { + for _, el := range in { + if err := VisitTableName(el, f); err != nil { return err } } return nil } - -// EqualsRefOfWhen does deep equals between the two objects. -func EqualsRefOfWhen(a, b *When) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsExpr(a.Cond, b.Cond) && - EqualsExpr(a.Val, b.Val) -} - -// CloneRefOfWhen creates a deep clone of the input. -func CloneRefOfWhen(n *When) *When { - if n == nil { - return nil - } - out := *n - out.Cond = CloneExpr(n.Cond) - out.Val = CloneExpr(n.Val) - return &out +func VisitTableOptions(in TableOptions, f Visit) error { + _, err := f(in) + return err } - -// VisitRefOfWhen will visit all parts of the AST -func VisitRefOfWhen(in *When, f Visit) error { +func VisitUpdateExprs(in UpdateExprs, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Cond, f); err != nil { - return err - } - if err := VisitExpr(in.Val, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfWhere does deep equals between the two objects. -func EqualsRefOfWhere(a, b *Where) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Type == b.Type && - EqualsExpr(a.Expr, b.Expr) -} - -// CloneRefOfWhere creates a deep clone of the input. -func CloneRefOfWhere(n *Where) *Where { - if n == nil { - return nil + for _, el := range in { + if err := VisitRefOfUpdateExpr(el, f); err != nil { + return err + } } - out := *n - out.Expr = CloneExpr(n.Expr) - return &out + return nil } - -// VisitRefOfWhere will visit all parts of the AST -func VisitRefOfWhere(in *Where, f Visit) error { +func VisitValTuple(in ValTuple, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Expr, f); err != nil { - return err + for _, el := range in { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } - -// EqualsRefOfXorExpr does deep equals between the two objects. -func EqualsRefOfXorExpr(a, b *XorExpr) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsExpr(a.Left, b.Left) && - EqualsExpr(a.Right, b.Right) -} - -// CloneRefOfXorExpr creates a deep clone of the input. -func CloneRefOfXorExpr(n *XorExpr) *XorExpr { - if n == nil { - return nil - } - out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) - return &out -} - -// VisitRefOfXorExpr will visit all parts of the AST -func VisitRefOfXorExpr(in *XorExpr, f Visit) error { +func VisitValues(in Values, f Visit) error { if in == nil { return nil } if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Left, f); err != nil { + for _, el := range in { + if err := VisitValTuple(el, f); err != nil { + return err + } + } + return nil +} +func VisitVindexParam(in VindexParam, f Visit) error { + if cont, err := f(in); err != nil || !cont { return err } - if err := VisitExpr(in.Right, f); err != nil { + if err := VisitColIdent(in.Key, f); err != nil { return err } return nil } - -// EqualsAlterOption does deep equals between the two objects. -func EqualsAlterOption(inA, inB AlterOption) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteAccessMode(parent SQLNode, node AccessMode, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if inA == nil || inB == nil { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } - switch a := inA.(type) { - case *AddColumns: - b, ok := inB.(*AddColumns) - if !ok { - return false + return nil +} +func (a *application) rewriteAlgorithmValue(parent SQLNode, node AlgorithmValue, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfAddColumns(a, b) - case *AddConstraintDefinition: - b, ok := inB.(*AddConstraintDefinition) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfAddConstraintDefinition(a, b) - case *AddIndexDefinition: - b, ok := inB.(*AddIndexDefinition) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAddIndexDefinition(a, b) + } + return nil +} +func (a *application) rewriteAlterOption(parent SQLNode, node AlterOption, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AddColumns: + return a.rewriteRefOfAddColumns(parent, node, replacer) + case *AddConstraintDefinition: + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) + case *AddIndexDefinition: + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) case AlgorithmValue: - b, ok := inB.(AlgorithmValue) - if !ok { - return false - } - return a == b + return a.rewriteAlgorithmValue(parent, node, replacer) case *AlterCharset: - b, ok := inB.(*AlterCharset) - if !ok { - return false - } - return EqualsRefOfAlterCharset(a, b) + return a.rewriteRefOfAlterCharset(parent, node, replacer) case *AlterColumn: - b, ok := inB.(*AlterColumn) - if !ok { - return false - } - return EqualsRefOfAlterColumn(a, b) + return a.rewriteRefOfAlterColumn(parent, node, replacer) case *ChangeColumn: - b, ok := inB.(*ChangeColumn) - if !ok { - return false - } - return EqualsRefOfChangeColumn(a, b) + return a.rewriteRefOfChangeColumn(parent, node, replacer) case *DropColumn: - b, ok := inB.(*DropColumn) - if !ok { - return false - } - return EqualsRefOfDropColumn(a, b) + return a.rewriteRefOfDropColumn(parent, node, replacer) case *DropKey: - b, ok := inB.(*DropKey) - if !ok { - return false - } - return EqualsRefOfDropKey(a, b) + return a.rewriteRefOfDropKey(parent, node, replacer) case *Force: - b, ok := inB.(*Force) - if !ok { - return false - } - return EqualsRefOfForce(a, b) + return a.rewriteRefOfForce(parent, node, replacer) case *KeyState: - b, ok := inB.(*KeyState) - if !ok { - return false - } - return EqualsRefOfKeyState(a, b) + return a.rewriteRefOfKeyState(parent, node, replacer) case *LockOption: - b, ok := inB.(*LockOption) - if !ok { - return false - } - return EqualsRefOfLockOption(a, b) + return a.rewriteRefOfLockOption(parent, node, replacer) case *ModifyColumn: - b, ok := inB.(*ModifyColumn) - if !ok { - return false - } - return EqualsRefOfModifyColumn(a, b) + return a.rewriteRefOfModifyColumn(parent, node, replacer) case *OrderByOption: - b, ok := inB.(*OrderByOption) - if !ok { - return false - } - return EqualsRefOfOrderByOption(a, b) + return a.rewriteRefOfOrderByOption(parent, node, replacer) case *RenameIndex: - b, ok := inB.(*RenameIndex) - if !ok { - return false - } - return EqualsRefOfRenameIndex(a, b) + return a.rewriteRefOfRenameIndex(parent, node, replacer) case *RenameTableName: - b, ok := inB.(*RenameTableName) - if !ok { - return false - } - return EqualsRefOfRenameTableName(a, b) + return a.rewriteRefOfRenameTableName(parent, node, replacer) case TableOptions: - b, ok := inB.(TableOptions) - if !ok { - return false - } - return EqualsTableOptions(a, b) + return a.rewriteTableOptions(parent, node, replacer) case *TablespaceOperation: - b, ok := inB.(*TablespaceOperation) - if !ok { - return false - } - return EqualsRefOfTablespaceOperation(a, b) + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) case *Validation: - b, ok := inB.(*Validation) - if !ok { - return false + return a.rewriteRefOfValidation(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteArgument(parent SQLNode, node Argument, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfValidation(a, b) + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteBoolVal(parent SQLNode, node BoolVal, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteCharacteristic(parent SQLNode, node Characteristic, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case AccessMode: + return a.rewriteAccessMode(parent, node, replacer) + case IsolationLevel: + return a.rewriteIsolationLevel(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteColIdent(parent SQLNode, node ColIdent, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if err != nil { + return err + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteColTuple(parent SQLNode, node ColTuple, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Columns)[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteComments(parent SQLNode, node Comments, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteConstraintInfo(parent SQLNode, node ConstraintInfo, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *CheckConstraintDefinition: + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) + case *ForeignKeyDefinition: + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteDBDDLStatement(parent SQLNode, node DBDDLStatement, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + default: + // this should never happen + return nil + } +} +func (a *application) rewriteDDLStatement(parent SQLNode, node DDLStatement, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) default: // this should never happen - return false + return nil } } - -// CloneAlterOption creates a deep clone of the input. -func CloneAlterOption(in AlterOption) AlterOption { - if in == nil { +func (a *application) rewriteExplain(parent SQLNode, node Explain, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AddColumns: - return CloneRefOfAddColumns(in) - case *AddConstraintDefinition: - return CloneRefOfAddConstraintDefinition(in) - case *AddIndexDefinition: - return CloneRefOfAddIndexDefinition(in) - case AlgorithmValue: - return in - case *AlterCharset: - return CloneRefOfAlterCharset(in) - case *AlterColumn: - return CloneRefOfAlterColumn(in) - case *ChangeColumn: - return CloneRefOfChangeColumn(in) - case *DropColumn: - return CloneRefOfDropColumn(in) - case *DropKey: - return CloneRefOfDropKey(in) - case *Force: - return CloneRefOfForce(in) - case *KeyState: - return CloneRefOfKeyState(in) - case *LockOption: - return CloneRefOfLockOption(in) - case *ModifyColumn: - return CloneRefOfModifyColumn(in) - case *OrderByOption: - return CloneRefOfOrderByOption(in) - case *RenameIndex: - return CloneRefOfRenameIndex(in) - case *RenameTableName: - return CloneRefOfRenameTableName(in) - case TableOptions: - return CloneTableOptions(in) - case *TablespaceOperation: - return CloneRefOfTablespaceOperation(in) - case *Validation: - return CloneRefOfValidation(in) + switch node := node.(type) { + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) default: // this should never happen return nil } } - -// VisitAlterOption will visit all parts of the AST -func VisitAlterOption(in AlterOption, f Visit) error { - if in == nil { +func (a *application) rewriteExpr(parent SQLNode, node Expr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AddColumns: - return VisitRefOfAddColumns(in, f) - case *AddConstraintDefinition: - return VisitRefOfAddConstraintDefinition(in, f) - case *AddIndexDefinition: - return VisitRefOfAddIndexDefinition(in, f) - case AlgorithmValue: - return VisitAlgorithmValue(in, f) - case *AlterCharset: - return VisitRefOfAlterCharset(in, f) - case *AlterColumn: - return VisitRefOfAlterColumn(in, f) - case *ChangeColumn: - return VisitRefOfChangeColumn(in, f) - case *DropColumn: - return VisitRefOfDropColumn(in, f) - case *DropKey: - return VisitRefOfDropKey(in, f) - case *Force: - return VisitRefOfForce(in, f) - case *KeyState: - return VisitRefOfKeyState(in, f) - case *LockOption: - return VisitRefOfLockOption(in, f) - case *ModifyColumn: - return VisitRefOfModifyColumn(in, f) - case *OrderByOption: - return VisitRefOfOrderByOption(in, f) - case *RenameIndex: - return VisitRefOfRenameIndex(in, f) - case *RenameTableName: - return VisitRefOfRenameTableName(in, f) - case TableOptions: - return VisitTableOptions(in, f) - case *TablespaceOperation: - return VisitRefOfTablespaceOperation(in, f) - case *Validation: - return VisitRefOfValidation(in, f) + switch node := node.(type) { + case *AndExpr: + return a.rewriteRefOfAndExpr(parent, node, replacer) + case Argument: + return a.rewriteArgument(parent, node, replacer) + case *BinaryExpr: + return a.rewriteRefOfBinaryExpr(parent, node, replacer) + case BoolVal: + return a.rewriteBoolVal(parent, node, replacer) + case *CaseExpr: + return a.rewriteRefOfCaseExpr(parent, node, replacer) + case *ColName: + return a.rewriteRefOfColName(parent, node, replacer) + case *CollateExpr: + return a.rewriteRefOfCollateExpr(parent, node, replacer) + case *ComparisonExpr: + return a.rewriteRefOfComparisonExpr(parent, node, replacer) + case *ConvertExpr: + return a.rewriteRefOfConvertExpr(parent, node, replacer) + case *ConvertUsingExpr: + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) + case *CurTimeFuncExpr: + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) + case *Default: + return a.rewriteRefOfDefault(parent, node, replacer) + case *ExistsExpr: + return a.rewriteRefOfExistsExpr(parent, node, replacer) + case *FuncExpr: + return a.rewriteRefOfFuncExpr(parent, node, replacer) + case *GroupConcatExpr: + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) + case *IntervalExpr: + return a.rewriteRefOfIntervalExpr(parent, node, replacer) + case *IsExpr: + return a.rewriteRefOfIsExpr(parent, node, replacer) + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Literal: + return a.rewriteRefOfLiteral(parent, node, replacer) + case *MatchExpr: + return a.rewriteRefOfMatchExpr(parent, node, replacer) + case *NotExpr: + return a.rewriteRefOfNotExpr(parent, node, replacer) + case *NullVal: + return a.rewriteRefOfNullVal(parent, node, replacer) + case *OrExpr: + return a.rewriteRefOfOrExpr(parent, node, replacer) + case *RangeCond: + return a.rewriteRefOfRangeCond(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case *SubstrExpr: + return a.rewriteRefOfSubstrExpr(parent, node, replacer) + case *TimestampFuncExpr: + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) + case *UnaryExpr: + return a.rewriteRefOfUnaryExpr(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + case *ValuesFuncExpr: + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) + case *XorExpr: + return a.rewriteRefOfXorExpr(parent, node, replacer) default: // this should never happen return nil } } - -// EqualsCharacteristic does deep equals between the two objects. -func EqualsCharacteristic(inA, inB Characteristic) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case AccessMode: - b, ok := inB.(AccessMode) - if !ok { - return false + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(Exprs)[i] = newNode.(Expr) + }); errF != nil { + return errF } - return a == b - case IsolationLevel: - b, ok := inB.(IsolationLevel) - if !ok { - return false + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return a == b - default: - // this should never happen - return false } + return nil } - -// CloneCharacteristic creates a deep clone of the input. -func CloneCharacteristic(in Characteristic) Characteristic { - if in == nil { +func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case AccessMode: - return in - case IsolationLevel: - return in + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(GroupBy)[i] = newNode.(Expr) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteInsertRows(parent SQLNode, node InsertRows, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case Values: + return a.rewriteValues(parent, node, replacer) default: // this should never happen return nil } } - -// VisitCharacteristic will visit all parts of the AST -func VisitCharacteristic(in Characteristic, f Visit) error { - if in == nil { +func (a *application) rewriteIsolationLevel(parent SQLNode, node IsolationLevel, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteJoinCondition(parent SQLNode, node JoinCondition, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'On' on 'JoinCondition'") + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Using' on 'JoinCondition'") + }); errF != nil { + return errF + } + if err != nil { + return err + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteListArg(parent SQLNode, node ListArg, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case AccessMode: - return VisitAccessMode(in, f) - case IsolationLevel: - return VisitIsolationLevel(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(OnDup)[i] = newNode.(*UpdateExpr) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsColTuple does deep equals between the two objects. -func EqualsColTuple(inA, inB ColTuple) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false + for i, el := range node { + if errF := a.rewriteRefOfOrder(node, el, func(newNode, parent SQLNode) { + parent.(OrderBy)[i] = newNode.(*Order) + }); errF != nil { + return errF } - return EqualsListArg(a, b) - case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfSubquery(a, b) - case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false + } + return nil +} +func (a *application) rewritePartitions(parent SQLNode, node Partitions, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(Partitions)[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node.Columns { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*AddColumns).Columns[i] = newNode.(*ColumnDefinition) + }); errF != nil { + return errF + } + } + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*AddColumns).First = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*AddColumns).After = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsValTuple(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneColTuple creates a deep clone of the input. -func CloneColTuple(in ColTuple) ColTuple { - if in == nil { +func (a *application) rewriteRefOfAddConstraintDefinition(parent SQLNode, node *AddConstraintDefinition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case ListArg: - return CloneListArg(in) - case *Subquery: - return CloneRefOfSubquery(in) - case ValTuple: - return CloneValTuple(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfConstraintDefinition(node, node.ConstraintDefinition, func(newNode, parent SQLNode) { + parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfAddIndexDefinition(parent SQLNode, node *AddIndexDefinition, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfIndexDefinition(node, node.IndexDefinition, func(newNode, parent SQLNode) { + parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitColTuple will visit all parts of the AST -func VisitColTuple(in ColTuple, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfAliasedExpr(parent SQLNode, node *AliasedExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case ListArg: - return VisitListArg(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case ValTuple: - return VisitValTuple(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteColIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedExpr).As = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfAliasedTableExpr(parent SQLNode, node *AliasedTableExpr, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSimpleTableExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) + }); errF != nil { + return errF + } + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteTableIdent(node, node.As, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).As = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfIndexHints(node, node.Hints, func(newNode, parent SQLNode) { + parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsConstraintInfo does deep equals between the two objects. -func EqualsConstraintInfo(inA, inB ConstraintInfo) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfAlterCharset(parent SQLNode, node *AlterCharset, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *CheckConstraintDefinition: - b, ok := inB.(*CheckConstraintDefinition) - if !ok { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfCheckConstraintDefinition(a, b) - case *ForeignKeyDefinition: - b, ok := inB.(*ForeignKeyDefinition) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfForeignKeyDefinition(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneConstraintInfo creates a deep clone of the input. -func CloneConstraintInfo(in ConstraintInfo) ConstraintInfo { - if in == nil { +func (a *application) rewriteRefOfAlterColumn(parent SQLNode, node *AlterColumn, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *CheckConstraintDefinition: - return CloneRefOfCheckConstraintDefinition(in) - case *ForeignKeyDefinition: - return CloneRefOfForeignKeyDefinition(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AlterColumn).Column = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.DefaultVal, func(newNode, parent SQLNode) { + parent.(*AlterColumn).DefaultVal = newNode.(Expr) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitConstraintInfo will visit all parts of the AST -func VisitConstraintInfo(in ConstraintInfo, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfAlterDatabase(parent SQLNode, node *AlterDatabase, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *CheckConstraintDefinition: - return VisitRefOfCheckConstraintDefinition(in, f) - case *ForeignKeyDefinition: - return VisitRefOfForeignKeyDefinition(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } -} - -// EqualsDBDDLStatement does deep equals between the two objects. -func EqualsDBDDLStatement(inA, inB DBDDLStatement) bool { - if inA == nil && inB == nil { - return true + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } - if inA == nil || inB == nil { - return false + return nil +} +func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigration, replacer replacerFunc) error { + if node == nil { + return nil } - switch a := inA.(type) { - case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfAlterDatabase(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfCreateDatabase(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfDropDatabase(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneDBDDLStatement creates a deep clone of the input. -func CloneDBDDLStatement(in DBDDLStatement) DBDDLStatement { - if in == nil { +func (a *application) rewriteRefOfAlterTable(parent SQLNode, node *AlterTable, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterTable).Table = newNode.(TableName) + }); errF != nil { + return errF + } + for i, el := range node.AlterOptions { + if errF := a.rewriteAlterOption(node, el, func(newNode, parent SQLNode) { + parent.(*AlterTable).AlterOptions[i] = newNode.(AlterOption) + }); errF != nil { + return errF + } + } + if errF := a.rewriteRefOfPartitionSpec(node, node.PartitionSpec, func(newNode, parent SQLNode) { + parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitDBDDLStatement will visit all parts of the AST -func VisitDBDDLStatement(in DBDDLStatement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfAlterView(parent SQLNode, node *AlterView, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*AlterView).ViewName = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*AlterView).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*AlterView).Select = newNode.(SelectStatement) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsDDLStatement does deep equals between the two objects. -func EqualsDDLStatement(inA, inB DDLStatement) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschema, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*AlterVschema).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfVindexSpec(node, node.VindexSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) + }); errF != nil { + return errF + } + for i, el := range node.VindexCols { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*AlterVschema).VindexCols[i] = newNode.(ColIdent) + }); errF != nil { + return errF } - return EqualsRefOfAlterTable(a, b) - case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false + } + if errF := a.rewriteRefOfAutoIncSpec(node, node.AutoIncSpec, func(newNode, parent SQLNode) { + parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAlterView(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*AndExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*AndExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCreateView(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfAutoIncSpec(parent SQLNode, node *AutoIncSpec, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false + } + if errF := a.rewriteColIdent(node, node.Column, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Column = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Sequence, func(newNode, parent SQLNode) { + parent.(*AutoIncSpec).Sequence = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfBegin(parent SQLNode, node *Begin, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfDropView(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfRenameTable(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfTruncateTable(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneDDLStatement creates a deep clone of the input. -func CloneDDLStatement(in DDLStatement) DDLStatement { - if in == nil { +func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*BinaryExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// VisitDDLStatement will visit all parts of the AST -func VisitDDLStatement(in DDLStatement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CallProc).Name = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteExprs(node, node.Params, func(newNode, parent SQLNode) { + parent.(*CallProc).Params = newNode.(Exprs) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsExplain does deep equals between the two objects. -func EqualsExplain(inA, inB Explain) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfCaseExpr(parent SQLNode, node *CaseExpr, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + for i, el := range node.Whens { + if errF := a.rewriteRefOfWhen(node, el, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Whens[i] = newNode.(*When) + }); errF != nil { + return errF } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Else, func(newNode, parent SQLNode) { + parent.(*CaseExpr).Else = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfExplainTab(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneExplain creates a deep clone of the input. -func CloneExplain(in Explain) Explain { - if in == nil { +func (a *application) rewriteRefOfChangeColumn(parent SQLNode, node *ChangeColumn, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.OldColumn, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).OldColumn = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).First = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ChangeColumn).After = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// VisitExplain will visit all parts of the AST -func VisitExplain(in Explain, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfCheckConstraintDefinition(parent SQLNode, node *CheckConstraintDefinition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfColIdent(parent SQLNode, node *ColIdent, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsExpr does deep equals between the two objects. -func EqualsExpr(inA, inB Expr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfColName(parent SQLNode, node *ColName, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *AndExpr: - b, ok := inB.(*AndExpr) - if !ok { - return false + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColName).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*ColName).Qualifier = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAndExpr(a, b) - case Argument: - b, ok := inB.(Argument) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCollateExpr(parent SQLNode, node *CollateExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return a == b - case *BinaryExpr: - b, ok := inB.(*BinaryExpr) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*CollateExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfBinaryExpr(a, b) - case BoolVal: - b, ok := inB.(BoolVal) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfColumnDefinition(parent SQLNode, node *ColumnDefinition, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return a == b - case *CaseExpr: - b, ok := inB.(*CaseExpr) - if !ok { - return false + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ColumnDefinition).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCaseExpr(a, b) - case *ColName: - b, ok := inB.(*ColName) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfColumnType(parent SQLNode, node *ColumnType, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfColName(a, b) - case *CollateExpr: - b, ok := inB.(*CollateExpr) - if !ok { - return false + } + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ColumnType).Length = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ColumnType).Scale = newNode.(*Literal) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCollateExpr(a, b) - case *ComparisonExpr: - b, ok := inB.(*ComparisonExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCommit(parent SQLNode, node *Commit, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfComparisonExpr(a, b) - case *ConvertExpr: - b, ok := inB.(*ConvertExpr) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfConvertExpr(a, b) - case *ConvertUsingExpr: - b, ok := inB.(*ConvertUsingExpr) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfConvertUsingExpr(a, b) - case *CurTimeFuncExpr: - b, ok := inB.(*CurTimeFuncExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *ComparisonExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfCurTimeFuncExpr(a, b) - case *Default: - b, ok := inB.(*Default) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Escape, func(newNode, parent SQLNode) { + parent.(*ComparisonExpr).Escape = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfConstraintDefinition(parent SQLNode, node *ConstraintDefinition, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfDefault(a, b) - case *ExistsExpr: - b, ok := inB.(*ExistsExpr) - if !ok { - return false + } + if errF := a.rewriteConstraintInfo(node, node.Details, func(newNode, parent SQLNode) { + parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfExistsExpr(a, b) - case *FuncExpr: - b, ok := inB.(*FuncExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfConvertExpr(parent SQLNode, node *ConvertExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfFuncExpr(a, b) - case *GroupConcatExpr: - b, ok := inB.(*GroupConcatExpr) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfConvertType(node, node.Type, func(newNode, parent SQLNode) { + parent.(*ConvertExpr).Type = newNode.(*ConvertType) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfGroupConcatExpr(a, b) - case *IntervalExpr: - b, ok := inB.(*IntervalExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfConvertType(parent SQLNode, node *ConvertType, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfIntervalExpr(a, b) - case *IsExpr: - b, ok := inB.(*IsExpr) - if !ok { - return false + } + if errF := a.rewriteRefOfLiteral(node, node.Length, func(newNode, parent SQLNode) { + parent.(*ConvertType).Length = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.Scale, func(newNode, parent SQLNode) { + parent.(*ConvertType).Scale = newNode.(*Literal) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfIsExpr(a, b) - case ListArg: - b, ok := inB.(ListArg) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfConvertUsingExpr(parent SQLNode, node *ConvertUsingExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsListArg(a, b) - case *Literal: - b, ok := inB.(*Literal) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*ConvertUsingExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfLiteral(a, b) - case *MatchExpr: - b, ok := inB.(*MatchExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCreateDatabase(parent SQLNode, node *CreateDatabase, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfMatchExpr(a, b) - case *NotExpr: - b, ok := inB.(*NotExpr) - if !ok { - return false + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*CreateDatabase).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfNotExpr(a, b) - case *NullVal: - b, ok := inB.(*NullVal) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfNullVal(a, b) - case *OrExpr: - b, ok := inB.(*OrExpr) - if !ok { - return false + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*CreateTable).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfTableSpec(node, node.TableSpec, func(newNode, parent SQLNode) { + parent.(*CreateTable).TableSpec = newNode.(*TableSpec) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfOptLike(node, node.OptLike, func(newNode, parent SQLNode) { + parent.(*CreateTable).OptLike = newNode.(*OptLike) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfOrExpr(a, b) - case *RangeCond: - b, ok := inB.(*RangeCond) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCreateView(parent SQLNode, node *CreateView, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfRangeCond(a, b) - case *Subquery: - b, ok := inB.(*Subquery) - if !ok { - return false + } + if errF := a.rewriteTableName(node, node.ViewName, func(newNode, parent SQLNode) { + parent.(*CreateView).ViewName = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*CreateView).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*CreateView).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfSubquery(a, b) - case *SubstrExpr: - b, ok := inB.(*SubstrExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfCurTimeFuncExpr(parent SQLNode, node *CurTimeFuncExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfSubstrExpr(a, b) - case *TimestampFuncExpr: - b, ok := inB.(*TimestampFuncExpr) - if !ok { - return false + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Fsp, func(newNode, parent SQLNode) { + parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfTimestampFuncExpr(a, b) - case *UnaryExpr: - b, ok := inB.(*UnaryExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfDefault(parent SQLNode, node *Default, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfUnaryExpr(a, b) - case ValTuple: - b, ok := inB.(ValTuple) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsValTuple(a, b) - case *ValuesFuncExpr: - b, ok := inB.(*ValuesFuncExpr) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfValuesFuncExpr(a, b) - case *XorExpr: - b, ok := inB.(*XorExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfXorExpr(a, b) - default: - // this should never happen - return false } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Delete).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) + }); errF != nil { + return errF + } + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Delete).Partitions = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Delete).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Delete).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Delete).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneExpr creates a deep clone of the input. -func CloneExpr(in Expr) Expr { - if in == nil { +func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTable, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AndExpr: - return CloneRefOfAndExpr(in) - case Argument: - return in - case *BinaryExpr: - return CloneRefOfBinaryExpr(in) - case BoolVal: - return in - case *CaseExpr: - return CloneRefOfCaseExpr(in) - case *ColName: - return CloneRefOfColName(in) - case *CollateExpr: - return CloneRefOfCollateExpr(in) - case *ComparisonExpr: - return CloneRefOfComparisonExpr(in) - case *ConvertExpr: - return CloneRefOfConvertExpr(in) - case *ConvertUsingExpr: - return CloneRefOfConvertUsingExpr(in) - case *CurTimeFuncExpr: - return CloneRefOfCurTimeFuncExpr(in) - case *Default: - return CloneRefOfDefault(in) - case *ExistsExpr: - return CloneRefOfExistsExpr(in) - case *FuncExpr: - return CloneRefOfFuncExpr(in) - case *GroupConcatExpr: - return CloneRefOfGroupConcatExpr(in) - case *IntervalExpr: - return CloneRefOfIntervalExpr(in) - case *IsExpr: - return CloneRefOfIsExpr(in) - case ListArg: - return CloneListArg(in) - case *Literal: - return CloneRefOfLiteral(in) - case *MatchExpr: - return CloneRefOfMatchExpr(in) - case *NotExpr: - return CloneRefOfNotExpr(in) - case *NullVal: - return CloneRefOfNullVal(in) - case *OrExpr: - return CloneRefOfOrExpr(in) - case *RangeCond: - return CloneRefOfRangeCond(in) - case *Subquery: - return CloneRefOfSubquery(in) - case *SubstrExpr: - return CloneRefOfSubstrExpr(in) - case *TimestampFuncExpr: - return CloneRefOfTimestampFuncExpr(in) - case *UnaryExpr: - return CloneRefOfUnaryExpr(in) - case ValTuple: - return CloneValTuple(in) - case *ValuesFuncExpr: - return CloneRefOfValuesFuncExpr(in) - case *XorExpr: - return CloneRefOfXorExpr(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*DerivedTable).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfDropColumn(parent SQLNode, node *DropColumn, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*DropColumn).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitExpr will visit all parts of the AST -func VisitExpr(in Expr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfDropDatabase(parent SQLNode, node *DropDatabase, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AndExpr: - return VisitRefOfAndExpr(in, f) - case Argument: - return VisitArgument(in, f) - case *BinaryExpr: - return VisitRefOfBinaryExpr(in, f) - case BoolVal: - return VisitBoolVal(in, f) - case *CaseExpr: - return VisitRefOfCaseExpr(in, f) - case *ColName: - return VisitRefOfColName(in, f) - case *CollateExpr: - return VisitRefOfCollateExpr(in, f) - case *ComparisonExpr: - return VisitRefOfComparisonExpr(in, f) - case *ConvertExpr: - return VisitRefOfConvertExpr(in, f) - case *ConvertUsingExpr: - return VisitRefOfConvertUsingExpr(in, f) - case *CurTimeFuncExpr: - return VisitRefOfCurTimeFuncExpr(in, f) - case *Default: - return VisitRefOfDefault(in, f) - case *ExistsExpr: - return VisitRefOfExistsExpr(in, f) - case *FuncExpr: - return VisitRefOfFuncExpr(in, f) - case *GroupConcatExpr: - return VisitRefOfGroupConcatExpr(in, f) - case *IntervalExpr: - return VisitRefOfIntervalExpr(in, f) - case *IsExpr: - return VisitRefOfIsExpr(in, f) - case ListArg: - return VisitListArg(in, f) - case *Literal: - return VisitRefOfLiteral(in, f) - case *MatchExpr: - return VisitRefOfMatchExpr(in, f) - case *NotExpr: - return VisitRefOfNotExpr(in, f) - case *NullVal: - return VisitRefOfNullVal(in, f) - case *OrExpr: - return VisitRefOfOrExpr(in, f) - case *RangeCond: - return VisitRefOfRangeCond(in, f) - case *Subquery: - return VisitRefOfSubquery(in, f) - case *SubstrExpr: - return VisitRefOfSubstrExpr(in, f) - case *TimestampFuncExpr: - return VisitRefOfTimestampFuncExpr(in, f) - case *UnaryExpr: - return VisitRefOfUnaryExpr(in, f) - case ValTuple: - return VisitValTuple(in, f) - case *ValuesFuncExpr: - return VisitRefOfValuesFuncExpr(in, f) - case *XorExpr: - return VisitRefOfXorExpr(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*DropDatabase).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// EqualsInsertRows does deep equals between the two objects. -func EqualsInsertRows(inA, inB InsertRows) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfParenSelect(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfSelect(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfUnion(a, b) - case Values: - b, ok := inB.(Values) - if !ok { - return false + } + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropTable).FromTables = newNode.(TableNames) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsValues(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneInsertRows creates a deep clone of the input. -func CloneInsertRows(in InsertRows) InsertRows { - if in == nil { +func (a *application) rewriteRefOfDropView(parent SQLNode, node *DropView, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Select: - return CloneRefOfSelect(in) - case *Union: - return CloneRefOfUnion(in) - case Values: - return CloneValues(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { + parent.(*DropView).FromTables = newNode.(TableNames) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitInsertRows will visit all parts of the AST -func VisitInsertRows(in InsertRows, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfExistsExpr(parent SQLNode, node *ExistsExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case Values: - return VisitValues(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*ExistsExpr).Subquery = newNode.(*Subquery) + }); errF != nil { + return errF } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSelectExpr does deep equals between the two objects. -func EqualsSelectExpr(inA, inB SelectExpr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *AliasedExpr: - b, ok := inB.(*AliasedExpr) - if !ok { - return false + if errF := a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*ExplainStmt).Statement = newNode.(Statement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAliasedExpr(a, b) - case *Nextval: - b, ok := inB.(*Nextval) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfExplainTab(parent SQLNode, node *ExplainTab, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfNextval(a, b) - case *StarExpr: - b, ok := inB.(*StarExpr) - if !ok { - return false + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ExplainTab).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfStarExpr(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneSelectExpr creates a deep clone of the input. -func CloneSelectExpr(in SelectExpr) SelectExpr { - if in == nil { +func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedExpr: - return CloneRefOfAliasedExpr(in) - case *Nextval: - return CloneRefOfNextval(in) - case *StarExpr: - return CloneRefOfStarExpr(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { + parent.(*Flush).TableNames = newNode.(TableNames) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfForce(parent SQLNode, node *Force, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitSelectExpr will visit all parts of the AST -func VisitSelectExpr(in SelectExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfForeignKeyDefinition(parent SQLNode, node *ForeignKeyDefinition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedExpr: - return VisitRefOfAliasedExpr(in, f) - case *Nextval: - return VisitRefOfNextval(in, f) - case *StarExpr: - return VisitRefOfStarExpr(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteColumns(node, node.Source, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).Source = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.ReferencedTable, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.ReferencedColumns, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteReferenceAction(node, node.OnDelete, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) + }); errF != nil { + return errF + } + if errF := a.rewriteReferenceAction(node, node.OnUpdate, func(newNode, parent SQLNode) { + parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Qualifier = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*FuncExpr).Exprs = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSelectStatement does deep equals between the two objects. -func EqualsSelectStatement(inA, inB SelectStatement) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupConcatExpr, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false + if errF := a.rewriteSelectExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfParenSelect(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfIndexDefinition(parent SQLNode, node *IndexDefinition, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfSelect(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false + } + if errF := a.rewriteRefOfIndexInfo(node, node.Info, func(newNode, parent SQLNode) { + parent.(*IndexDefinition).Info = newNode.(*IndexInfo) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfUnion(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneSelectStatement creates a deep clone of the input. -func CloneSelectStatement(in SelectStatement) SelectStatement { - if in == nil { +func (a *application) rewriteRefOfIndexHints(parent SQLNode, node *IndexHints, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Select: - return CloneRefOfSelect(in) - case *Union: - return CloneRefOfUnion(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + for i, el := range node.Indexes { + if errF := a.rewriteColIdent(node, el, func(newNode, parent SQLNode) { + parent.(*IndexHints).Indexes[i] = newNode.(ColIdent) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// VisitSelectStatement will visit all parts of the AST -func VisitSelectStatement(in SelectStatement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Union: - return VisitRefOfUnion(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } -} - -// EqualsShowInternal does deep equals between the two objects. -func EqualsShowInternal(inA, inB ShowInternal) bool { - if inA == nil && inB == nil { - return true + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*IndexInfo).Name = newNode.(ColIdent) + }); errF != nil { + return errF } - if inA == nil || inB == nil { - return false + if errF := a.rewriteColIdent(node, node.ConstraintName, func(newNode, parent SQLNode) { + parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) + }); errF != nil { + return errF } - switch a := inA.(type) { - case *ShowBasic: - b, ok := inB.(*ShowBasic) - if !ok { - return false + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfShowBasic(a, b) - case *ShowCreate: - b, ok := inB.(*ShowCreate) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfShowCreate(a, b) - case *ShowLegacy: - b, ok := inB.(*ShowLegacy) - if !ok { - return false + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Insert).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Insert).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewritePartitions(node, node.Partitions, func(newNode, parent SQLNode) { + parent.(*Insert).Partitions = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*Insert).Columns = newNode.(Columns) + }); errF != nil { + return errF + } + if errF := a.rewriteInsertRows(node, node.Rows, func(newNode, parent SQLNode) { + parent.(*Insert).Rows = newNode.(InsertRows) + }); errF != nil { + return errF + } + if errF := a.rewriteOnDup(node, node.OnDup, func(newNode, parent SQLNode) { + parent.(*Insert).OnDup = newNode.(OnDup) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfShowLegacy(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneShowInternal creates a deep clone of the input. -func CloneShowInternal(in ShowInternal) ShowInternal { - if in == nil { +func (a *application) rewriteRefOfIntervalExpr(parent SQLNode, node *IntervalExpr, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ShowBasic: - return CloneRefOfShowBasic(in) - case *ShowCreate: - return CloneRefOfShowCreate(in) - case *ShowLegacy: - return CloneRefOfShowLegacy(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IntervalExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfIsExpr(parent SQLNode, node *IsExpr, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*IsExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitShowInternal will visit all parts of the AST -func VisitShowInternal(in ShowInternal, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfJoinCondition(parent SQLNode, node *JoinCondition, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *ShowBasic: - return VisitRefOfShowBasic(in, f) - case *ShowCreate: - return VisitRefOfShowCreate(in, f) - case *ShowLegacy: - return VisitRefOfShowLegacy(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.On, func(newNode, parent SQLNode) { + parent.(*JoinCondition).On = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteColumns(node, node.Using, func(newNode, parent SQLNode) { + parent.(*JoinCondition).Using = newNode.(Columns) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfJoinTableExpr(parent SQLNode, node *JoinTableExpr, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableExpr(node, node.LeftExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExpr(node, node.RightExpr, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteJoinCondition(node, node.Condition, func(newNode, parent SQLNode) { + parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSimpleTableExpr does deep equals between the two objects. -func EqualsSimpleTableExpr(inA, inB SimpleTableExpr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfKeyState(parent SQLNode, node *KeyState, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *DerivedTable: - b, ok := inB.(*DerivedTable) - if !ok { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfDerivedTable(a, b) - case TableName: - b, ok := inB.(TableName) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsTableName(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneSimpleTableExpr creates a deep clone of the input. -func CloneSimpleTableExpr(in SimpleTableExpr) SimpleTableExpr { - if in == nil { +func (a *application) rewriteRefOfLimit(parent SQLNode, node *Limit, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *DerivedTable: - return CloneRefOfDerivedTable(in) - case TableName: - return CloneTableName(in) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Offset, func(newNode, parent SQLNode) { + parent.(*Limit).Offset = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Rowcount, func(newNode, parent SQLNode) { + parent.(*Limit).Rowcount = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// VisitSimpleTableExpr will visit all parts of the AST -func VisitSimpleTableExpr(in SimpleTableExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfLiteral(parent SQLNode, node *Literal, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *DerivedTable: - return VisitRefOfDerivedTable(in, f) - case TableName: - return VisitTableName(in, f) - default: - // this should never happen - return nil + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } + return nil } - -// EqualsStatement does deep equals between the two objects. -func EqualsStatement(inA, inB Statement) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfLoad(parent SQLNode, node *Load, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *AlterDatabase: - b, ok := inB.(*AlterDatabase) - if !ok { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfAlterDatabase(a, b) - case *AlterMigration: - b, ok := inB.(*AlterMigration) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAlterMigration(a, b) - case *AlterTable: - b, ok := inB.(*AlterTable) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfLockOption(parent SQLNode, node *LockOption, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfAlterTable(a, b) - case *AlterView: - b, ok := inB.(*AlterView) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfAlterView(a, b) - case *AlterVschema: - b, ok := inB.(*AlterVschema) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfLockTables(parent SQLNode, node *LockTables, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfAlterVschema(a, b) - case *Begin: - b, ok := inB.(*Begin) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfBegin(a, b) - case *CallProc: - b, ok := inB.(*CallProc) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCallProc(a, b) - case *Commit: - b, ok := inB.(*Commit) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfCommit(a, b) - case *CreateDatabase: - b, ok := inB.(*CreateDatabase) - if !ok { - return false + } + if errF := a.rewriteSelectExprs(node, node.Columns, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Columns = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*MatchExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCreateDatabase(a, b) - case *CreateTable: - b, ok := inB.(*CreateTable) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfModifyColumn(parent SQLNode, node *ModifyColumn, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfCreateTable(a, b) - case *CreateView: - b, ok := inB.(*CreateView) - if !ok { - return false + } + if errF := a.rewriteRefOfColumnDefinition(node, node.NewColDefinition, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.First, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).First = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfColName(node, node.After, func(newNode, parent SQLNode) { + parent.(*ModifyColumn).After = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfCreateView(a, b) - case *Delete: - b, ok := inB.(*Delete) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfNextval(parent SQLNode, node *Nextval, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfDelete(a, b) - case *DropDatabase: - b, ok := inB.(*DropDatabase) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Nextval).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfDropDatabase(a, b) - case *DropTable: - b, ok := inB.(*DropTable) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfNotExpr(parent SQLNode, node *NotExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfDropTable(a, b) - case *DropView: - b, ok := inB.(*DropView) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*NotExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfDropView(a, b) - case *ExplainStmt: - b, ok := inB.(*ExplainStmt) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfNullVal(parent SQLNode, node *NullVal, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfExplainStmt(a, b) - case *ExplainTab: - b, ok := inB.(*ExplainTab) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfExplainTab(a, b) - case *Flush: - b, ok := inB.(*Flush) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfFlush(a, b) - case *Insert: - b, ok := inB.(*Insert) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOptLike(parent SQLNode, node *OptLike, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfInsert(a, b) - case *Load: - b, ok := inB.(*Load) - if !ok { - return false + } + if errF := a.rewriteTableName(node, node.LikeTable, func(newNode, parent SQLNode) { + parent.(*OptLike).LikeTable = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfLoad(a, b) - case *LockTables: - b, ok := inB.(*LockTables) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOrExpr(parent SQLNode, node *OrExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfLockTables(a, b) - case *OtherAdmin: - b, ok := inB.(*OtherAdmin) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*OrExpr).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*OrExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfOtherAdmin(a, b) - case *OtherRead: - b, ok := inB.(*OtherRead) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOrder(parent SQLNode, node *Order, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfOtherRead(a, b) - case *ParenSelect: - b, ok := inB.(*ParenSelect) - if !ok { - return false + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Order).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfParenSelect(a, b) - case *Release: - b, ok := inB.(*Release) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOrderByOption(parent SQLNode, node *OrderByOption, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfRelease(a, b) - case *RenameTable: - b, ok := inB.(*RenameTable) - if !ok { - return false + } + if errF := a.rewriteColumns(node, node.Cols, func(newNode, parent SQLNode) { + parent.(*OrderByOption).Cols = newNode.(Columns) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfRenameTable(a, b) - case *RevertMigration: - b, ok := inB.(*RevertMigration) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOtherAdmin(parent SQLNode, node *OtherAdmin, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfRevertMigration(a, b) - case *Rollback: - b, ok := inB.(*Rollback) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfRollback(a, b) - case *SRollback: - b, ok := inB.(*SRollback) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfSRollback(a, b) - case *Savepoint: - b, ok := inB.(*Savepoint) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfOtherRead(parent SQLNode, node *OtherRead, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfSavepoint(a, b) - case *Select: - b, ok := inB.(*Select) - if !ok { - return false + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfSelect(a, b) - case *Set: - b, ok := inB.(*Set) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfSet(a, b) - case *SetTransaction: - b, ok := inB.(*SetTransaction) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfParenSelect(parent SQLNode, node *ParenSelect, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfSetTransaction(a, b) - case *Show: - b, ok := inB.(*Show) - if !ok { - return false + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*ParenSelect).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfShow(a, b) - case *Stream: - b, ok := inB.(*Stream) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfParenTableExpr(parent SQLNode, node *ParenTableExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfStream(a, b) - case *TruncateTable: - b, ok := inB.(*TruncateTable) - if !ok { - return false + } + if errF := a.rewriteTableExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfTruncateTable(a, b) - case *Union: - b, ok := inB.(*Union) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfPartitionDefinition(parent SQLNode, node *PartitionDefinition, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfUnion(a, b) - case *UnlockTables: - b, ok := inB.(*UnlockTables) - if !ok { - return false + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*PartitionDefinition).Limit = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfUnlockTables(a, b) - case *Update: - b, ok := inB.(*Update) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionSpec, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } - return EqualsRefOfUpdate(a, b) - case *Use: - b, ok := inB.(*Use) - if !ok { - return false + } + if errF := a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Names = newNode.(Partitions) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.Number, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Number = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).TableName = newNode.(TableName) + }); errF != nil { + return errF + } + for i, el := range node.Definitions { + if errF := a.rewriteRefOfPartitionDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*PartitionSpec).Definitions[i] = newNode.(*PartitionDefinition) + }); errF != nil { + return errF } - return EqualsRefOfUse(a, b) - case *VStream: - b, ok := inB.(*VStream) - if !ok { - return false + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfVStream(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneStatement creates a deep clone of the input. -func CloneStatement(in Statement) Statement { - if in == nil { +func (a *application) rewriteRefOfRangeCond(parent SQLNode, node *RangeCond, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return CloneRefOfAlterDatabase(in) - case *AlterMigration: - return CloneRefOfAlterMigration(in) - case *AlterTable: - return CloneRefOfAlterTable(in) - case *AlterView: - return CloneRefOfAlterView(in) - case *AlterVschema: - return CloneRefOfAlterVschema(in) - case *Begin: - return CloneRefOfBegin(in) - case *CallProc: - return CloneRefOfCallProc(in) - case *Commit: - return CloneRefOfCommit(in) - case *CreateDatabase: - return CloneRefOfCreateDatabase(in) - case *CreateTable: - return CloneRefOfCreateTable(in) - case *CreateView: - return CloneRefOfCreateView(in) - case *Delete: - return CloneRefOfDelete(in) - case *DropDatabase: - return CloneRefOfDropDatabase(in) - case *DropTable: - return CloneRefOfDropTable(in) - case *DropView: - return CloneRefOfDropView(in) - case *ExplainStmt: - return CloneRefOfExplainStmt(in) - case *ExplainTab: - return CloneRefOfExplainTab(in) - case *Flush: - return CloneRefOfFlush(in) - case *Insert: - return CloneRefOfInsert(in) - case *Load: - return CloneRefOfLoad(in) - case *LockTables: - return CloneRefOfLockTables(in) - case *OtherAdmin: - return CloneRefOfOtherAdmin(in) - case *OtherRead: - return CloneRefOfOtherRead(in) - case *ParenSelect: - return CloneRefOfParenSelect(in) - case *Release: - return CloneRefOfRelease(in) - case *RenameTable: - return CloneRefOfRenameTable(in) - case *RevertMigration: - return CloneRefOfRevertMigration(in) - case *Rollback: - return CloneRefOfRollback(in) - case *SRollback: - return CloneRefOfSRollback(in) - case *Savepoint: - return CloneRefOfSavepoint(in) - case *Select: - return CloneRefOfSelect(in) - case *Set: - return CloneRefOfSet(in) - case *SetTransaction: - return CloneRefOfSetTransaction(in) - case *Show: - return CloneRefOfShow(in) - case *Stream: - return CloneRefOfStream(in) - case *TruncateTable: - return CloneRefOfTruncateTable(in) - case *Union: - return CloneRefOfUnion(in) - case *UnlockTables: - return CloneRefOfUnlockTables(in) - case *Update: - return CloneRefOfUpdate(in) - case *Use: - return CloneRefOfUse(in) - case *VStream: - return CloneRefOfVStream(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*RangeCond).Left = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*RangeCond).From = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*RangeCond).To = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfRelease(parent SQLNode, node *Release, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Release).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitStatement will visit all parts of the AST -func VisitStatement(in Statement, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfRenameIndex(parent SQLNode, node *RenameIndex, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AlterDatabase: - return VisitRefOfAlterDatabase(in, f) - case *AlterMigration: - return VisitRefOfAlterMigration(in, f) - case *AlterTable: - return VisitRefOfAlterTable(in, f) - case *AlterView: - return VisitRefOfAlterView(in, f) - case *AlterVschema: - return VisitRefOfAlterVschema(in, f) - case *Begin: - return VisitRefOfBegin(in, f) - case *CallProc: - return VisitRefOfCallProc(in, f) - case *Commit: - return VisitRefOfCommit(in, f) - case *CreateDatabase: - return VisitRefOfCreateDatabase(in, f) - case *CreateTable: - return VisitRefOfCreateTable(in, f) - case *CreateView: - return VisitRefOfCreateView(in, f) - case *Delete: - return VisitRefOfDelete(in, f) - case *DropDatabase: - return VisitRefOfDropDatabase(in, f) - case *DropTable: - return VisitRefOfDropTable(in, f) - case *DropView: - return VisitRefOfDropView(in, f) - case *ExplainStmt: - return VisitRefOfExplainStmt(in, f) - case *ExplainTab: - return VisitRefOfExplainTab(in, f) - case *Flush: - return VisitRefOfFlush(in, f) - case *Insert: - return VisitRefOfInsert(in, f) - case *Load: - return VisitRefOfLoad(in, f) - case *LockTables: - return VisitRefOfLockTables(in, f) - case *OtherAdmin: - return VisitRefOfOtherAdmin(in, f) - case *OtherRead: - return VisitRefOfOtherRead(in, f) - case *ParenSelect: - return VisitRefOfParenSelect(in, f) - case *Release: - return VisitRefOfRelease(in, f) - case *RenameTable: - return VisitRefOfRenameTable(in, f) - case *RevertMigration: - return VisitRefOfRevertMigration(in, f) - case *Rollback: - return VisitRefOfRollback(in, f) - case *SRollback: - return VisitRefOfSRollback(in, f) - case *Savepoint: - return VisitRefOfSavepoint(in, f) - case *Select: - return VisitRefOfSelect(in, f) - case *Set: - return VisitRefOfSet(in, f) - case *SetTransaction: - return VisitRefOfSetTransaction(in, f) - case *Show: - return VisitRefOfShow(in, f) - case *Stream: - return VisitRefOfStream(in, f) - case *TruncateTable: - return VisitRefOfTruncateTable(in, f) - case *Union: - return VisitRefOfUnion(in, f) - case *UnlockTables: - return VisitRefOfUnlockTables(in, f) - case *Update: - return VisitRefOfUpdate(in, f) - case *Use: - return VisitRefOfUse(in, f) - case *VStream: - return VisitRefOfVStream(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfRenameTable(parent SQLNode, node *RenameTable, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfRenameTableName(parent SQLNode, node *RenameTableName, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*RenameTableName).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsTableExpr does deep equals between the two objects. -func EqualsTableExpr(inA, inB TableExpr) bool { - if inA == nil && inB == nil { - return true +func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMigration, replacer replacerFunc) error { + if node == nil { + return nil } - if inA == nil || inB == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - switch a := inA.(type) { - case *AliasedTableExpr: - b, ok := inB.(*AliasedTableExpr) - if !ok { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node } - return EqualsRefOfAliasedTableExpr(a, b) - case *JoinTableExpr: - b, ok := inB.(*JoinTableExpr) - if !ok { - return false + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfJoinTableExpr(a, b) - case *ParenTableExpr: - b, ok := inB.(*ParenTableExpr) - if !ok { - return false + } + return nil +} +func (a *application) rewriteRefOfRollback(parent SQLNode, node *Rollback, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort } - return EqualsRefOfParenTableExpr(a, b) - default: - // this should never happen - return false } + return nil } - -// CloneTableExpr creates a deep clone of the input. -func CloneTableExpr(in TableExpr) TableExpr { - if in == nil { +func (a *application) rewriteRefOfSRollback(parent SQLNode, node *SRollback, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedTableExpr: - return CloneRefOfAliasedTableExpr(in) - case *JoinTableExpr: - return CloneRefOfJoinTableExpr(in) - case *ParenTableExpr: - return CloneRefOfParenTableExpr(in) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SRollback).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfSavepoint(parent SQLNode, node *Savepoint, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*Savepoint).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitTableExpr will visit all parts of the AST -func VisitTableExpr(in TableExpr, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer replacerFunc) error { + if node == nil { return nil } - switch in := in.(type) { - case *AliasedTableExpr: - return VisitRefOfAliasedTableExpr(in, f) - case *JoinTableExpr: - return VisitRefOfJoinTableExpr(in, f) - case *ParenTableExpr: - return VisitRefOfParenTableExpr(in, f) - default: - // this should never happen + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Select).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExprs(node, node.SelectExprs, func(newNode, parent SQLNode) { + parent.(*Select).SelectExprs = newNode.(SelectExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExprs(node, node.From, func(newNode, parent SQLNode) { + parent.(*Select).From = newNode.(TableExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Select).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteGroupBy(node, node.GroupBy, func(newNode, parent SQLNode) { + parent.(*Select).GroupBy = newNode.(GroupBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Having, func(newNode, parent SQLNode) { + parent.(*Select).Having = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Select).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Select).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfSelectInto(node, node.Into, func(newNode, parent SQLNode) { + parent.(*Select).Into = newNode.(*SelectInto) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfSelectInto(parent SQLNode, node *SelectInto, replacer replacerFunc) error { + if node == nil { return nil } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitAccessMode will visit all parts of the AST -func VisitAccessMode(in AccessMode, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfSet(parent SQLNode, node *Set, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Set).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSetExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Set).Exprs = newNode.(SetExprs) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitAlgorithmValue will visit all parts of the AST -func VisitAlgorithmValue(in AlgorithmValue, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfSetExpr(parent SQLNode, node *SetExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SetExpr).Name = newNode.(ColIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*SetExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitArgument will visit all parts of the AST -func VisitArgument(in Argument, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfSetTransaction(parent SQLNode, node *SetTransaction, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSQLNode(node, node.SQLNode, func(newNode, parent SQLNode) { + parent.(*SetTransaction).SQLNode = newNode.(SQLNode) + }); errF != nil { + return errF + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + for i, el := range node.Characteristics { + if errF := a.rewriteCharacteristic(node, el, func(newNode, parent SQLNode) { + parent.(*SetTransaction).Characteristics[i] = newNode.(Characteristic) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitBoolVal will visit all parts of the AST -func VisitBoolVal(in BoolVal, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfShow(parent SQLNode, node *Show, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteShowInternal(node, node.Internal, func(newNode, parent SQLNode) { + parent.(*Show).Internal = newNode.(ShowInternal) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitIsolationLevel will visit all parts of the AST -func VisitIsolationLevel(in IsolationLevel, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Tbl = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfShowFilter(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowBasic).Filter = newNode.(*ShowFilter) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitReferenceAction will visit all parts of the AST -func VisitReferenceAction(in ReferenceAction, f Visit) error { - _, err := f(in) - return err +func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { + parent.(*ShowCreate).Op = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSliceOfRefOfColumnDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfColumnDefinition(a, b []*ColumnDefinition) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { + parent.(*ShowFilter).Filter = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfShowLegacy(parent SQLNode, node *ShowLegacy, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.OnTable, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).OnTable = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.ShowCollationFilterOpt, func(newNode, parent SQLNode) { + parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfStarExpr(parent SQLNode, node *StarExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.TableName, func(newNode, parent SQLNode) { + parent.(*StarExpr).TableName = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfStream(parent SQLNode, node *Stream, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Stream).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*Stream).SelectExpr = newNode.(SelectExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*Stream).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfSubquery(parent SQLNode, node *Subquery, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSelectStatement(node, node.Select, func(newNode, parent SQLNode) { + parent.(*Subquery).Select = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfSubstrExpr(parent SQLNode, node *SubstrExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLiteral(node, node.StrVal, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).StrVal = newNode.(*Literal) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.From, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).From = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.To, func(newNode, parent SQLNode) { + parent.(*SubstrExpr).To = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteRefOfTableIdent(parent SQLNode, node *TableIdent, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - for i := 0; i < len(a); i++ { - if !EqualsRefOfColumnDefinition(a[i], b[i]) { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort } } - return true + return nil } - -// CloneSliceOfRefOfColumnDefinition creates a deep clone of the input. -func CloneSliceOfRefOfColumnDefinition(n []*ColumnDefinition) []*ColumnDefinition { - res := make([]*ColumnDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfColumnDefinition(x)) +func (a *application) rewriteRefOfTableName(parent SQLNode, node *TableName, replacer replacerFunc) error { + if node == nil { + return nil } - return res -} - -// EqualsSliceOfCollateAndCharset does deep equals between the two objects. -func EqualsSliceOfCollateAndCharset(a, b []CollateAndCharset) bool { - if len(a) != len(b) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - for i := 0; i < len(a); i++ { - if !EqualsCollateAndCharset(a[i], b[i]) { - return false + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*TableName).Name = newNode.(TableIdent) + }); errF != nil { + return errF + } + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + parent.(*TableName).Qualifier = newNode.(TableIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } } - return true + return nil } - -// CloneSliceOfCollateAndCharset creates a deep clone of the input. -func CloneSliceOfCollateAndCharset(n []CollateAndCharset) []CollateAndCharset { - res := make([]CollateAndCharset, 0, len(n)) - for _, x := range n { - res = append(res, CloneCollateAndCharset(x)) +func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, replacer replacerFunc) error { + if node == nil { + return nil } - return res -} - -// EqualsSliceOfAlterOption does deep equals between the two objects. -func EqualsSliceOfAlterOption(a, b []AlterOption) bool { - if len(a) != len(b) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - for i := 0; i < len(a); i++ { - if !EqualsAlterOption(a[i], b[i]) { - return false + for i, el := range node.Columns { + if errF := a.rewriteRefOfColumnDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Columns[i] = newNode.(*ColumnDefinition) + }); errF != nil { + return errF } } - return true + for i, el := range node.Indexes { + if errF := a.rewriteRefOfIndexDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Indexes[i] = newNode.(*IndexDefinition) + }); errF != nil { + return errF + } + } + for i, el := range node.Constraints { + if errF := a.rewriteRefOfConstraintDefinition(node, el, func(newNode, parent SQLNode) { + parent.(*TableSpec).Constraints[i] = newNode.(*ConstraintDefinition) + }); errF != nil { + return errF + } + } + if errF := a.rewriteTableOptions(node, node.Options, func(newNode, parent SQLNode) { + parent.(*TableSpec).Options = newNode.(TableOptions) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneSliceOfAlterOption creates a deep clone of the input. -func CloneSliceOfAlterOption(n []AlterOption) []AlterOption { - res := make([]AlterOption, 0, len(n)) - for _, x := range n { - res = append(res, CloneAlterOption(x)) +func (a *application) rewriteRefOfTablespaceOperation(parent SQLNode, node *TablespaceOperation, replacer replacerFunc) error { + if node == nil { + return nil } - return res + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSliceOfColIdent does deep equals between the two objects. -func EqualsSliceOfColIdent(a, b []ColIdent) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfTimestampFuncExpr(parent SQLNode, node *TimestampFuncExpr, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsColIdent(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true + if errF := a.rewriteExpr(node, node.Expr1, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr2, func(newNode, parent SQLNode) { + parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneSliceOfColIdent creates a deep clone of the input. -func CloneSliceOfColIdent(n []ColIdent) []ColIdent { - res := make([]ColIdent, 0, len(n)) - for _, x := range n { - res = append(res, CloneColIdent(x)) +func (a *application) rewriteRefOfTruncateTable(parent SQLNode, node *TruncateTable, replacer replacerFunc) error { + if node == nil { + return nil } - return res + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*TruncateTable).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSliceOfRefOfWhen does deep equals between the two objects. -func EqualsSliceOfRefOfWhen(a, b []*When) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfWhen(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UnaryExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneSliceOfRefOfWhen creates a deep clone of the input. -func CloneSliceOfRefOfWhen(n []*When) []*When { - res := make([]*When, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfWhen(x)) +func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer replacerFunc) error { + if node == nil { + return nil } - return res + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteSelectStatement(node, node.FirstStatement, func(newNode, parent SQLNode) { + parent.(*Union).FirstStatement = newNode.(SelectStatement) + }); errF != nil { + return errF + } + for i, el := range node.UnionSelects { + if errF := a.rewriteRefOfUnionSelect(node, el, func(newNode, parent SQLNode) { + parent.(*Union).UnionSelects[i] = newNode.(*UnionSelect) + }); errF != nil { + return errF + } + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Union).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Union).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsRefOfColIdent does deep equals between the two objects. -func EqualsRefOfColIdent(a, b *ColIdent) bool { - if a == b { - return true +func (a *application) rewriteRefOfUnionSelect(parent SQLNode, node *UnionSelect, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - return a.val == b.val && - a.lowered == b.lowered && - a.at == b.at + if errF := a.rewriteSelectStatement(node, node.Statement, func(newNode, parent SQLNode) { + parent.(*UnionSelect).Statement = newNode.(SelectStatement) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfColIdent creates a deep clone of the input. -func CloneRefOfColIdent(n *ColIdent) *ColIdent { - if n == nil { +func (a *application) rewriteRefOfUnlockTables(parent SQLNode, node *UnlockTables, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - return &out + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitRefOfColIdent will visit all parts of the AST -func VisitRefOfColIdent(in *ColIdent, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*Update).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Update).TableExprs = newNode.(TableExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteUpdateExprs(node, node.Exprs, func(newNode, parent SQLNode) { + parent.(*Update).Exprs = newNode.(UpdateExprs) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*Update).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { + parent.(*Update).OrderBy = newNode.(OrderBy) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*Update).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } return nil } - -// EqualsColumnType does deep equals between the two objects. -func EqualsColumnType(a, b ColumnType) bool { - return a.Type == b.Type && - a.Unsigned == b.Unsigned && - a.Zerofill == b.Zerofill && - a.Charset == b.Charset && - a.Collate == b.Collate && - EqualsRefOfColumnTypeOptions(a.Options, b.Options) && - EqualsRefOfLiteral(a.Length, b.Length) && - EqualsRefOfLiteral(a.Scale, b.Scale) && - EqualsSliceOfString(a.EnumValues, b.EnumValues) +func (a *application) rewriteRefOfUpdateExpr(parent SQLNode, node *UpdateExpr, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*UpdateExpr).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneColumnType creates a deep clone of the input. -func CloneColumnType(n ColumnType) ColumnType { - return *CloneRefOfColumnType(&n) +func (a *application) rewriteRefOfUse(parent SQLNode, node *Use, replacer replacerFunc) error { + if node == nil { + return nil + } + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteTableIdent(node, node.DBName, func(newNode, parent SQLNode) { + parent.(*Use).DBName = newNode.(TableIdent) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsRefOfColumnTypeOptions does deep equals between the two objects. -func EqualsRefOfColumnTypeOptions(a, b *ColumnTypeOptions) bool { - if a == b { - return true +func (a *application) rewriteRefOfVStream(parent SQLNode, node *VStream, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - return a.NotNull == b.NotNull && - a.Autoincrement == b.Autoincrement && - EqualsExpr(a.Default, b.Default) && - EqualsExpr(a.OnUpdate, b.OnUpdate) && - EqualsRefOfLiteral(a.Comment, b.Comment) && - a.KeyOpt == b.KeyOpt + if errF := a.rewriteComments(node, node.Comments, func(newNode, parent SQLNode) { + parent.(*VStream).Comments = newNode.(Comments) + }); errF != nil { + return errF + } + if errF := a.rewriteSelectExpr(node, node.SelectExpr, func(newNode, parent SQLNode) { + parent.(*VStream).SelectExpr = newNode.(SelectExpr) + }); errF != nil { + return errF + } + if errF := a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { + parent.(*VStream).Table = newNode.(TableName) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfWhere(node, node.Where, func(newNode, parent SQLNode) { + parent.(*VStream).Where = newNode.(*Where) + }); errF != nil { + return errF + } + if errF := a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { + parent.(*VStream).Limit = newNode.(*Limit) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfColumnTypeOptions creates a deep clone of the input. -func CloneRefOfColumnTypeOptions(n *ColumnTypeOptions) *ColumnTypeOptions { - if n == nil { +func (a *application) rewriteRefOfValidation(parent SQLNode, node *Validation, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Default = CloneExpr(n.Default) - out.OnUpdate = CloneExpr(n.OnUpdate) - out.Comment = CloneRefOfLiteral(n.Comment) - return &out + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSliceOfString does deep equals between the two objects. -func EqualsSliceOfString(a, b []string) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfValuesFuncExpr(parent SQLNode, node *ValuesFuncExpr, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteRefOfColName(node, node.Name, func(newNode, parent SQLNode) { + parent.(*ValuesFuncExpr).Name = newNode.(*ColName) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } } - return true -} - -// CloneSliceOfString creates a deep clone of the input. -func CloneSliceOfString(n []string) []string { - res := make([]string, 0, len(n)) - copy(res, n) - return res + return nil } - -// EqualsSliceOfRefOfIndexColumn does deep equals between the two objects. -func EqualsSliceOfRefOfIndexColumn(a, b []*IndexColumn) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfVindexParam(parent SQLNode, node *VindexParam, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexColumn(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfRefOfIndexColumn creates a deep clone of the input. -func CloneSliceOfRefOfIndexColumn(n []*IndexColumn) []*IndexColumn { - res := make([]*IndexColumn, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexColumn(x)) + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + parent.(*VindexParam).Key = newNode.(ColIdent) + }); errF != nil { + return errF } - return res + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsSliceOfRefOfIndexOption does deep equals between the two objects. -func EqualsSliceOfRefOfIndexOption(a, b []*IndexOption) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexOption(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfRefOfIndexOption creates a deep clone of the input. -func CloneSliceOfRefOfIndexOption(n []*IndexOption) []*IndexOption { - res := make([]*IndexOption, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexOption(x)) + if errF := a.rewriteColIdent(node, node.Name, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Name = newNode.(ColIdent) + }); errF != nil { + return errF } - return res + if errF := a.rewriteColIdent(node, node.Type, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Type = newNode.(ColIdent) + }); errF != nil { + return errF + } + for i, el := range node.Params { + if errF := a.rewriteVindexParam(node, el, func(newNode, parent SQLNode) { + parent.(*VindexSpec).Params[i] = newNode.(VindexParam) + }); errF != nil { + return errF + } + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// EqualsRefOfJoinCondition does deep equals between the two objects. -func EqualsRefOfJoinCondition(a, b *JoinCondition) bool { - if a == b { - return true +func (a *application) rewriteRefOfWhen(parent SQLNode, node *When, replacer replacerFunc) error { + if node == nil { + return nil } - if a == nil || b == nil { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - return EqualsExpr(a.On, b.On) && - EqualsColumns(a.Using, b.Using) + if errF := a.rewriteExpr(node, node.Cond, func(newNode, parent SQLNode) { + parent.(*When).Cond = newNode.(Expr) + }); errF != nil { + return errF + } + if errF := a.rewriteExpr(node, node.Val, func(newNode, parent SQLNode) { + parent.(*When).Val = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfJoinCondition creates a deep clone of the input. -func CloneRefOfJoinCondition(n *JoinCondition) *JoinCondition { - if n == nil { +func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.On = CloneExpr(n.On) - out.Using = CloneColumns(n.Using) - return &out + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if errF := a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { + parent.(*Where).Expr = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// VisitRefOfJoinCondition will visit all parts of the AST -func VisitRefOfJoinCondition(in *JoinCondition, f Visit) error { - if in == nil { +func (a *application) rewriteRefOfXorExpr(parent SQLNode, node *XorExpr, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if err := VisitExpr(in.On, f); err != nil { - return err + if errF := a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { + parent.(*XorExpr).Left = newNode.(Expr) + }); errF != nil { + return errF } - if err := VisitColumns(in.Using, f); err != nil { - return err + if errF := a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { + parent.(*XorExpr).Right = newNode.(Expr) + }); errF != nil { + return errF + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteReferenceAction(parent SQLNode, node ReferenceAction, replacer replacerFunc) error { + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } + } + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } + } + return nil +} +func (a *application) rewriteSQLNode(parent SQLNode, node SQLNode, replacer replacerFunc) error { + if node == nil { + return nil + } + switch node := node.(type) { + case AccessMode: + return a.rewriteAccessMode(parent, node, replacer) + case *AddColumns: + return a.rewriteRefOfAddColumns(parent, node, replacer) + case *AddConstraintDefinition: + return a.rewriteRefOfAddConstraintDefinition(parent, node, replacer) + case *AddIndexDefinition: + return a.rewriteRefOfAddIndexDefinition(parent, node, replacer) + case AlgorithmValue: + return a.rewriteAlgorithmValue(parent, node, replacer) + case *AliasedExpr: + return a.rewriteRefOfAliasedExpr(parent, node, replacer) + case *AliasedTableExpr: + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) + case *AlterCharset: + return a.rewriteRefOfAlterCharset(parent, node, replacer) + case *AlterColumn: + return a.rewriteRefOfAlterColumn(parent, node, replacer) + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *AlterMigration: + return a.rewriteRefOfAlterMigration(parent, node, replacer) + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *AlterVschema: + return a.rewriteRefOfAlterVschema(parent, node, replacer) + case *AndExpr: + return a.rewriteRefOfAndExpr(parent, node, replacer) + case Argument: + return a.rewriteArgument(parent, node, replacer) + case *AutoIncSpec: + return a.rewriteRefOfAutoIncSpec(parent, node, replacer) + case *Begin: + return a.rewriteRefOfBegin(parent, node, replacer) + case *BinaryExpr: + return a.rewriteRefOfBinaryExpr(parent, node, replacer) + case BoolVal: + return a.rewriteBoolVal(parent, node, replacer) + case *CallProc: + return a.rewriteRefOfCallProc(parent, node, replacer) + case *CaseExpr: + return a.rewriteRefOfCaseExpr(parent, node, replacer) + case *ChangeColumn: + return a.rewriteRefOfChangeColumn(parent, node, replacer) + case *CheckConstraintDefinition: + return a.rewriteRefOfCheckConstraintDefinition(parent, node, replacer) + case ColIdent: + return a.rewriteColIdent(parent, node, replacer) + case *ColName: + return a.rewriteRefOfColName(parent, node, replacer) + case *CollateExpr: + return a.rewriteRefOfCollateExpr(parent, node, replacer) + case *ColumnDefinition: + return a.rewriteRefOfColumnDefinition(parent, node, replacer) + case *ColumnType: + return a.rewriteRefOfColumnType(parent, node, replacer) + case Columns: + return a.rewriteColumns(parent, node, replacer) + case Comments: + return a.rewriteComments(parent, node, replacer) + case *Commit: + return a.rewriteRefOfCommit(parent, node, replacer) + case *ComparisonExpr: + return a.rewriteRefOfComparisonExpr(parent, node, replacer) + case *ConstraintDefinition: + return a.rewriteRefOfConstraintDefinition(parent, node, replacer) + case *ConvertExpr: + return a.rewriteRefOfConvertExpr(parent, node, replacer) + case *ConvertType: + return a.rewriteRefOfConvertType(parent, node, replacer) + case *ConvertUsingExpr: + return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *CurTimeFuncExpr: + return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) + case *Default: + return a.rewriteRefOfDefault(parent, node, replacer) + case *Delete: + return a.rewriteRefOfDelete(parent, node, replacer) + case *DerivedTable: + return a.rewriteRefOfDerivedTable(parent, node, replacer) + case *DropColumn: + return a.rewriteRefOfDropColumn(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + case *DropKey: + return a.rewriteRefOfDropKey(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *ExistsExpr: + return a.rewriteRefOfExistsExpr(parent, node, replacer) + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) + case Exprs: + return a.rewriteExprs(parent, node, replacer) + case *Flush: + return a.rewriteRefOfFlush(parent, node, replacer) + case *Force: + return a.rewriteRefOfForce(parent, node, replacer) + case *ForeignKeyDefinition: + return a.rewriteRefOfForeignKeyDefinition(parent, node, replacer) + case *FuncExpr: + return a.rewriteRefOfFuncExpr(parent, node, replacer) + case GroupBy: + return a.rewriteGroupBy(parent, node, replacer) + case *GroupConcatExpr: + return a.rewriteRefOfGroupConcatExpr(parent, node, replacer) + case *IndexDefinition: + return a.rewriteRefOfIndexDefinition(parent, node, replacer) + case *IndexHints: + return a.rewriteRefOfIndexHints(parent, node, replacer) + case *IndexInfo: + return a.rewriteRefOfIndexInfo(parent, node, replacer) + case *Insert: + return a.rewriteRefOfInsert(parent, node, replacer) + case *IntervalExpr: + return a.rewriteRefOfIntervalExpr(parent, node, replacer) + case *IsExpr: + return a.rewriteRefOfIsExpr(parent, node, replacer) + case IsolationLevel: + return a.rewriteIsolationLevel(parent, node, replacer) + case JoinCondition: + return a.rewriteJoinCondition(parent, node, replacer) + case *JoinTableExpr: + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) + case *KeyState: + return a.rewriteRefOfKeyState(parent, node, replacer) + case *Limit: + return a.rewriteRefOfLimit(parent, node, replacer) + case ListArg: + return a.rewriteListArg(parent, node, replacer) + case *Literal: + return a.rewriteRefOfLiteral(parent, node, replacer) + case *Load: + return a.rewriteRefOfLoad(parent, node, replacer) + case *LockOption: + return a.rewriteRefOfLockOption(parent, node, replacer) + case *LockTables: + return a.rewriteRefOfLockTables(parent, node, replacer) + case *MatchExpr: + return a.rewriteRefOfMatchExpr(parent, node, replacer) + case *ModifyColumn: + return a.rewriteRefOfModifyColumn(parent, node, replacer) + case *Nextval: + return a.rewriteRefOfNextval(parent, node, replacer) + case *NotExpr: + return a.rewriteRefOfNotExpr(parent, node, replacer) + case *NullVal: + return a.rewriteRefOfNullVal(parent, node, replacer) + case OnDup: + return a.rewriteOnDup(parent, node, replacer) + case *OptLike: + return a.rewriteRefOfOptLike(parent, node, replacer) + case *OrExpr: + return a.rewriteRefOfOrExpr(parent, node, replacer) + case *Order: + return a.rewriteRefOfOrder(parent, node, replacer) + case OrderBy: + return a.rewriteOrderBy(parent, node, replacer) + case *OrderByOption: + return a.rewriteRefOfOrderByOption(parent, node, replacer) + case *OtherAdmin: + return a.rewriteRefOfOtherAdmin(parent, node, replacer) + case *OtherRead: + return a.rewriteRefOfOtherRead(parent, node, replacer) + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *ParenTableExpr: + return a.rewriteRefOfParenTableExpr(parent, node, replacer) + case *PartitionDefinition: + return a.rewriteRefOfPartitionDefinition(parent, node, replacer) + case *PartitionSpec: + return a.rewriteRefOfPartitionSpec(parent, node, replacer) + case Partitions: + return a.rewritePartitions(parent, node, replacer) + case *RangeCond: + return a.rewriteRefOfRangeCond(parent, node, replacer) + case ReferenceAction: + return a.rewriteReferenceAction(parent, node, replacer) + case *Release: + return a.rewriteRefOfRelease(parent, node, replacer) + case *RenameIndex: + return a.rewriteRefOfRenameIndex(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *RenameTableName: + return a.rewriteRefOfRenameTableName(parent, node, replacer) + case *RevertMigration: + return a.rewriteRefOfRevertMigration(parent, node, replacer) + case *Rollback: + return a.rewriteRefOfRollback(parent, node, replacer) + case *SRollback: + return a.rewriteRefOfSRollback(parent, node, replacer) + case *Savepoint: + return a.rewriteRefOfSavepoint(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case SelectExprs: + return a.rewriteSelectExprs(parent, node, replacer) + case *SelectInto: + return a.rewriteRefOfSelectInto(parent, node, replacer) + case *Set: + return a.rewriteRefOfSet(parent, node, replacer) + case *SetExpr: + return a.rewriteRefOfSetExpr(parent, node, replacer) + case SetExprs: + return a.rewriteSetExprs(parent, node, replacer) + case *SetTransaction: + return a.rewriteRefOfSetTransaction(parent, node, replacer) + case *Show: + return a.rewriteRefOfShow(parent, node, replacer) + case *ShowBasic: + return a.rewriteRefOfShowBasic(parent, node, replacer) + case *ShowCreate: + return a.rewriteRefOfShowCreate(parent, node, replacer) + case *ShowFilter: + return a.rewriteRefOfShowFilter(parent, node, replacer) + case *ShowLegacy: + return a.rewriteRefOfShowLegacy(parent, node, replacer) + case *StarExpr: + return a.rewriteRefOfStarExpr(parent, node, replacer) + case *Stream: + return a.rewriteRefOfStream(parent, node, replacer) + case *Subquery: + return a.rewriteRefOfSubquery(parent, node, replacer) + case *SubstrExpr: + return a.rewriteRefOfSubstrExpr(parent, node, replacer) + case TableExprs: + return a.rewriteTableExprs(parent, node, replacer) + case TableIdent: + return a.rewriteTableIdent(parent, node, replacer) + case TableName: + return a.rewriteTableName(parent, node, replacer) + case TableNames: + return a.rewriteTableNames(parent, node, replacer) + case TableOptions: + return a.rewriteTableOptions(parent, node, replacer) + case *TableSpec: + return a.rewriteRefOfTableSpec(parent, node, replacer) + case *TablespaceOperation: + return a.rewriteRefOfTablespaceOperation(parent, node, replacer) + case *TimestampFuncExpr: + return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) + case *UnaryExpr: + return a.rewriteRefOfUnaryExpr(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case *UnionSelect: + return a.rewriteRefOfUnionSelect(parent, node, replacer) + case *UnlockTables: + return a.rewriteRefOfUnlockTables(parent, node, replacer) + case *Update: + return a.rewriteRefOfUpdate(parent, node, replacer) + case *UpdateExpr: + return a.rewriteRefOfUpdateExpr(parent, node, replacer) + case UpdateExprs: + return a.rewriteUpdateExprs(parent, node, replacer) + case *Use: + return a.rewriteRefOfUse(parent, node, replacer) + case *VStream: + return a.rewriteRefOfVStream(parent, node, replacer) + case ValTuple: + return a.rewriteValTuple(parent, node, replacer) + case *Validation: + return a.rewriteRefOfValidation(parent, node, replacer) + case Values: + return a.rewriteValues(parent, node, replacer) + case *ValuesFuncExpr: + return a.rewriteRefOfValuesFuncExpr(parent, node, replacer) + case VindexParam: + return a.rewriteVindexParam(parent, node, replacer) + case *VindexSpec: + return a.rewriteRefOfVindexSpec(parent, node, replacer) + case *When: + return a.rewriteRefOfWhen(parent, node, replacer) + case *Where: + return a.rewriteRefOfWhere(parent, node, replacer) + case *XorExpr: + return a.rewriteRefOfXorExpr(parent, node, replacer) + default: + // this should never happen + return nil } - return nil } - -// EqualsTableAndLockTypes does deep equals between the two objects. -func EqualsTableAndLockTypes(a, b TableAndLockTypes) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfTableAndLockType(a[i], b[i]) { - return false - } +func (a *application) rewriteSelectExpr(parent SQLNode, node SelectExpr, replacer replacerFunc) error { + if node == nil { + return nil } - return true -} - -// CloneTableAndLockTypes creates a deep clone of the input. -func CloneTableAndLockTypes(n TableAndLockTypes) TableAndLockTypes { - res := make(TableAndLockTypes, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfTableAndLockType(x)) + switch node := node.(type) { + case *AliasedExpr: + return a.rewriteRefOfAliasedExpr(parent, node, replacer) + case *Nextval: + return a.rewriteRefOfNextval(parent, node, replacer) + case *StarExpr: + return a.rewriteRefOfStarExpr(parent, node, replacer) + default: + // this should never happen + return nil } - return res } - -// EqualsSliceOfRefOfPartitionDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfPartitionDefinition(a, b []*PartitionDefinition) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfPartitionDefinition(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfRefOfPartitionDefinition creates a deep clone of the input. -func CloneSliceOfRefOfPartitionDefinition(n []*PartitionDefinition) []*PartitionDefinition { - res := make([]*PartitionDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfPartitionDefinition(x)) - } - return res -} - -// EqualsSliceOfRefOfRenameTablePair does deep equals between the two objects. -func EqualsSliceOfRefOfRenameTablePair(a, b []*RenameTablePair) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfRenameTablePair(a[i], b[i]) { - return false + for i, el := range node { + if errF := a.rewriteSelectExpr(node, el, func(newNode, parent SQLNode) { + parent.(SelectExprs)[i] = newNode.(SelectExpr) + }); errF != nil { + return errF } } - return true -} - -// CloneSliceOfRefOfRenameTablePair creates a deep clone of the input. -func CloneSliceOfRefOfRenameTablePair(n []*RenameTablePair) []*RenameTablePair { - res := make([]*RenameTablePair, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfRenameTablePair(x)) + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } - return res + return nil } - -// EqualsRefOfBool does deep equals between the two objects. -func EqualsRefOfBool(a, b *bool) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false +func (a *application) rewriteSelectStatement(parent SQLNode, node SelectStatement, replacer replacerFunc) error { + if node == nil { + return nil } - return *a == *b -} - -// CloneRefOfBool creates a deep clone of the input. -func CloneRefOfBool(n *bool) *bool { - if n == nil { + switch node := node.(type) { + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - return &out } - -// EqualsSliceOfCharacteristic does deep equals between the two objects. -func EqualsSliceOfCharacteristic(a, b []Characteristic) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsCharacteristic(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfCharacteristic creates a deep clone of the input. -func CloneSliceOfCharacteristic(n []Characteristic) []Characteristic { - res := make([]Characteristic, 0, len(n)) - for _, x := range n { - res = append(res, CloneCharacteristic(x)) - } - return res -} - -// EqualsRefOfShowTablesOpt does deep equals between the two objects. -func EqualsRefOfShowTablesOpt(a, b *ShowTablesOpt) bool { - if a == b { - return true + for i, el := range node { + if errF := a.rewriteRefOfSetExpr(node, el, func(newNode, parent SQLNode) { + parent.(SetExprs)[i] = newNode.(*SetExpr) + }); errF != nil { + return errF + } } - if a == nil || b == nil { - return false + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } - return a.Full == b.Full && - a.DbName == b.DbName && - EqualsRefOfShowFilter(a.Filter, b.Filter) + return nil } - -// CloneRefOfShowTablesOpt creates a deep clone of the input. -func CloneRefOfShowTablesOpt(n *ShowTablesOpt) *ShowTablesOpt { - if n == nil { +func (a *application) rewriteShowInternal(parent SQLNode, node ShowInternal, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Filter = CloneRefOfShowFilter(n.Filter) - return &out -} - -// EqualsRefOfTableIdent does deep equals between the two objects. -func EqualsRefOfTableIdent(a, b *TableIdent) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.v == b.v -} - -// CloneRefOfTableIdent creates a deep clone of the input. -func CloneRefOfTableIdent(n *TableIdent) *TableIdent { - if n == nil { + switch node := node.(type) { + case *ShowBasic: + return a.rewriteRefOfShowBasic(parent, node, replacer) + case *ShowCreate: + return a.rewriteRefOfShowCreate(parent, node, replacer) + case *ShowLegacy: + return a.rewriteRefOfShowLegacy(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - return &out } - -// VisitRefOfTableIdent will visit all parts of the AST -func VisitRefOfTableIdent(in *TableIdent, f Visit) error { - if in == nil { +func (a *application) rewriteSimpleTableExpr(parent SQLNode, node SimpleTableExpr, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - return nil -} - -// EqualsRefOfTableName does deep equals between the two objects. -func EqualsRefOfTableName(a, b *TableName) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return EqualsTableIdent(a.Name, b.Name) && - EqualsTableIdent(a.Qualifier, b.Qualifier) -} - -// CloneRefOfTableName creates a deep clone of the input. -func CloneRefOfTableName(n *TableName) *TableName { - if n == nil { + switch node := node.(type) { + case *DerivedTable: + return a.rewriteRefOfDerivedTable(parent, node, replacer) + case TableName: + return a.rewriteTableName(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.Name = CloneTableIdent(n.Name) - out.Qualifier = CloneTableIdent(n.Qualifier) - return &out } - -// VisitRefOfTableName will visit all parts of the AST -func VisitRefOfTableName(in *TableName, f Visit) error { - if in == nil { +func (a *application) rewriteStatement(parent SQLNode, node Statement, replacer replacerFunc) error { + if node == nil { return nil } - if cont, err := f(in); err != nil || !cont { - return err - } - if err := VisitTableIdent(in.Name, f); err != nil { - return err - } - if err := VisitTableIdent(in.Qualifier, f); err != nil { - return err - } - return nil -} - -// EqualsRefOfTableOption does deep equals between the two objects. -func EqualsRefOfTableOption(a, b *TableOption) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - return a.Name == b.Name && - a.String == b.String && - EqualsRefOfLiteral(a.Value, b.Value) && - EqualsTableNames(a.Tables, b.Tables) -} - -// CloneRefOfTableOption creates a deep clone of the input. -func CloneRefOfTableOption(n *TableOption) *TableOption { - if n == nil { + switch node := node.(type) { + case *AlterDatabase: + return a.rewriteRefOfAlterDatabase(parent, node, replacer) + case *AlterMigration: + return a.rewriteRefOfAlterMigration(parent, node, replacer) + case *AlterTable: + return a.rewriteRefOfAlterTable(parent, node, replacer) + case *AlterView: + return a.rewriteRefOfAlterView(parent, node, replacer) + case *AlterVschema: + return a.rewriteRefOfAlterVschema(parent, node, replacer) + case *Begin: + return a.rewriteRefOfBegin(parent, node, replacer) + case *CallProc: + return a.rewriteRefOfCallProc(parent, node, replacer) + case *Commit: + return a.rewriteRefOfCommit(parent, node, replacer) + case *CreateDatabase: + return a.rewriteRefOfCreateDatabase(parent, node, replacer) + case *CreateTable: + return a.rewriteRefOfCreateTable(parent, node, replacer) + case *CreateView: + return a.rewriteRefOfCreateView(parent, node, replacer) + case *Delete: + return a.rewriteRefOfDelete(parent, node, replacer) + case *DropDatabase: + return a.rewriteRefOfDropDatabase(parent, node, replacer) + case *DropTable: + return a.rewriteRefOfDropTable(parent, node, replacer) + case *DropView: + return a.rewriteRefOfDropView(parent, node, replacer) + case *ExplainStmt: + return a.rewriteRefOfExplainStmt(parent, node, replacer) + case *ExplainTab: + return a.rewriteRefOfExplainTab(parent, node, replacer) + case *Flush: + return a.rewriteRefOfFlush(parent, node, replacer) + case *Insert: + return a.rewriteRefOfInsert(parent, node, replacer) + case *Load: + return a.rewriteRefOfLoad(parent, node, replacer) + case *LockTables: + return a.rewriteRefOfLockTables(parent, node, replacer) + case *OtherAdmin: + return a.rewriteRefOfOtherAdmin(parent, node, replacer) + case *OtherRead: + return a.rewriteRefOfOtherRead(parent, node, replacer) + case *ParenSelect: + return a.rewriteRefOfParenSelect(parent, node, replacer) + case *Release: + return a.rewriteRefOfRelease(parent, node, replacer) + case *RenameTable: + return a.rewriteRefOfRenameTable(parent, node, replacer) + case *RevertMigration: + return a.rewriteRefOfRevertMigration(parent, node, replacer) + case *Rollback: + return a.rewriteRefOfRollback(parent, node, replacer) + case *SRollback: + return a.rewriteRefOfSRollback(parent, node, replacer) + case *Savepoint: + return a.rewriteRefOfSavepoint(parent, node, replacer) + case *Select: + return a.rewriteRefOfSelect(parent, node, replacer) + case *Set: + return a.rewriteRefOfSet(parent, node, replacer) + case *SetTransaction: + return a.rewriteRefOfSetTransaction(parent, node, replacer) + case *Show: + return a.rewriteRefOfShow(parent, node, replacer) + case *Stream: + return a.rewriteRefOfStream(parent, node, replacer) + case *TruncateTable: + return a.rewriteRefOfTruncateTable(parent, node, replacer) + case *Union: + return a.rewriteRefOfUnion(parent, node, replacer) + case *UnlockTables: + return a.rewriteRefOfUnlockTables(parent, node, replacer) + case *Update: + return a.rewriteRefOfUpdate(parent, node, replacer) + case *Use: + return a.rewriteRefOfUse(parent, node, replacer) + case *VStream: + return a.rewriteRefOfVStream(parent, node, replacer) + default: + // this should never happen return nil } - out := *n - out.Value = CloneRefOfLiteral(n.Value) - out.Tables = CloneTableNames(n.Tables) - return &out } - -// EqualsSliceOfRefOfIndexDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfIndexDefinition(a, b []*IndexDefinition) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !EqualsRefOfIndexDefinition(a[i], b[i]) { - return false - } +func (a *application) rewriteTableExpr(parent SQLNode, node TableExpr, replacer replacerFunc) error { + if node == nil { + return nil } - return true -} - -// CloneSliceOfRefOfIndexDefinition creates a deep clone of the input. -func CloneSliceOfRefOfIndexDefinition(n []*IndexDefinition) []*IndexDefinition { - res := make([]*IndexDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfIndexDefinition(x)) + switch node := node.(type) { + case *AliasedTableExpr: + return a.rewriteRefOfAliasedTableExpr(parent, node, replacer) + case *JoinTableExpr: + return a.rewriteRefOfJoinTableExpr(parent, node, replacer) + case *ParenTableExpr: + return a.rewriteRefOfParenTableExpr(parent, node, replacer) + default: + // this should never happen + return nil } - return res } - -// EqualsSliceOfRefOfConstraintDefinition does deep equals between the two objects. -func EqualsSliceOfRefOfConstraintDefinition(a, b []*ConstraintDefinition) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsRefOfConstraintDefinition(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfRefOfConstraintDefinition creates a deep clone of the input. -func CloneSliceOfRefOfConstraintDefinition(n []*ConstraintDefinition) []*ConstraintDefinition { - res := make([]*ConstraintDefinition, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfConstraintDefinition(x)) - } - return res -} - -// EqualsSliceOfRefOfUnionSelect does deep equals between the two objects. -func EqualsSliceOfRefOfUnionSelect(a, b []*UnionSelect) bool { - if len(a) != len(b) { - return false + for i, el := range node { + if errF := a.rewriteTableExpr(node, el, func(newNode, parent SQLNode) { + parent.(TableExprs)[i] = newNode.(TableExpr) + }); errF != nil { + return errF + } } - for i := 0; i < len(a); i++ { - if !EqualsRefOfUnionSelect(a[i], b[i]) { - return false + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort } } - return true + return nil } - -// CloneSliceOfRefOfUnionSelect creates a deep clone of the input. -func CloneSliceOfRefOfUnionSelect(n []*UnionSelect) []*UnionSelect { - res := make([]*UnionSelect, 0, len(n)) - for _, x := range n { - res = append(res, CloneRefOfUnionSelect(x)) +func (a *application) rewriteTableIdent(parent SQLNode, node TableIdent, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - return res -} - -// EqualsRefOfVindexParam does deep equals between the two objects. -func EqualsRefOfVindexParam(a, b *VindexParam) bool { - if a == b { - return true + if err != nil { + return err } - if a == nil || b == nil { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } - return a.Val == b.Val && - EqualsColIdent(a.Key, b.Key) + return nil } - -// CloneRefOfVindexParam creates a deep clone of the input. -func CloneRefOfVindexParam(n *VindexParam) *VindexParam { - if n == nil { - return nil +func (a *application) rewriteTableName(parent SQLNode, node TableName, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - out := *n - out.Key = CloneColIdent(n.Key) - return &out -} - -// VisitRefOfVindexParam will visit all parts of the AST -func VisitRefOfVindexParam(in *VindexParam, f Visit) error { - if in == nil { - return nil + if errF := a.rewriteTableIdent(node, node.Name, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Name' on 'TableName'") + }); errF != nil { + return errF } - if cont, err := f(in); err != nil || !cont { - return err + if errF := a.rewriteTableIdent(node, node.Qualifier, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Qualifier' on 'TableName'") + }); errF != nil { + return errF } - if err := VisitColIdent(in.Key, f); err != nil { + if err != nil { return err } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } return nil } - -// EqualsSliceOfVindexParam does deep equals between the two objects. -func EqualsSliceOfVindexParam(a, b []VindexParam) bool { - if len(a) != len(b) { - return false +func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replacer replacerFunc) error { + if node == nil { + return nil } - for i := 0; i < len(a); i++ { - if !EqualsVindexParam(a[i], b[i]) { - return false + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil } } - return true -} - -// CloneSliceOfVindexParam creates a deep clone of the input. -func CloneSliceOfVindexParam(n []VindexParam) []VindexParam { - res := make([]VindexParam, 0, len(n)) - for _, x := range n { - res = append(res, CloneVindexParam(x)) - } - return res -} - -// EqualsCollateAndCharset does deep equals between the two objects. -func EqualsCollateAndCharset(a, b CollateAndCharset) bool { - return a.IsDefault == b.IsDefault && - a.Value == b.Value && - a.Type == b.Type -} - -// CloneCollateAndCharset creates a deep clone of the input. -func CloneCollateAndCharset(n CollateAndCharset) CollateAndCharset { - return *CloneRefOfCollateAndCharset(&n) -} - -// EqualsRefOfIndexColumn does deep equals between the two objects. -func EqualsRefOfIndexColumn(a, b *IndexColumn) bool { - if a == b { - return true + for i, el := range node { + if errF := a.rewriteTableName(node, el, func(newNode, parent SQLNode) { + parent.(TableNames)[i] = newNode.(TableName) + }); errF != nil { + return errF + } } - if a == nil || b == nil { - return false + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } } - return EqualsColIdent(a.Column, b.Column) && - EqualsRefOfLiteral(a.Length, b.Length) && - a.Direction == b.Direction + return nil } - -// CloneRefOfIndexColumn creates a deep clone of the input. -func CloneRefOfIndexColumn(n *IndexColumn) *IndexColumn { - if n == nil { +func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Column = CloneColIdent(n.Column) - out.Length = CloneRefOfLiteral(n.Length) - return &out -} - -// EqualsRefOfIndexOption does deep equals between the two objects. -func EqualsRefOfIndexOption(a, b *IndexOption) bool { - if a == b { - return true + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if a == nil || b == nil { - return false + if a.post != nil { + if a.pre == nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + } + if !a.post(&a.cur) { + return errAbort + } } - return a.Name == b.Name && - a.String == b.String && - EqualsRefOfLiteral(a.Value, b.Value) + return nil } - -// CloneRefOfIndexOption creates a deep clone of the input. -func CloneRefOfIndexOption(n *IndexOption) *IndexOption { - if n == nil { +func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Value = CloneRefOfLiteral(n.Value) - return &out -} - -// EqualsRefOfTableAndLockType does deep equals between the two objects. -func EqualsRefOfTableAndLockType(a, b *TableAndLockType) bool { - if a == b { - return true + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if a == nil || b == nil { - return false + for i, el := range node { + if errF := a.rewriteRefOfUpdateExpr(node, el, func(newNode, parent SQLNode) { + parent.(UpdateExprs)[i] = newNode.(*UpdateExpr) + }); errF != nil { + return errF + } } - return EqualsTableExpr(a.Table, b.Table) && - a.Lock == b.Lock + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfTableAndLockType creates a deep clone of the input. -func CloneRefOfTableAndLockType(n *TableAndLockType) *TableAndLockType { - if n == nil { +func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.Table = CloneTableExpr(n.Table) - return &out -} - -// EqualsRefOfRenameTablePair does deep equals between the two objects. -func EqualsRefOfRenameTablePair(a, b *RenameTablePair) bool { - if a == b { - return true + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if a == nil || b == nil { - return false + for i, el := range node { + if errF := a.rewriteExpr(node, el, func(newNode, parent SQLNode) { + parent.(ValTuple)[i] = newNode.(Expr) + }); errF != nil { + return errF + } } - return EqualsTableName(a.FromTable, b.FromTable) && - EqualsTableName(a.ToTable, b.ToTable) + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfRenameTablePair creates a deep clone of the input. -func CloneRefOfRenameTablePair(n *RenameTablePair) *RenameTablePair { - if n == nil { +func (a *application) rewriteValues(parent SQLNode, node Values, replacer replacerFunc) error { + if node == nil { return nil } - out := *n - out.FromTable = CloneTableName(n.FromTable) - out.ToTable = CloneTableName(n.ToTable) - return &out -} - -// EqualsRefOfCollateAndCharset does deep equals between the two objects. -func EqualsRefOfCollateAndCharset(a, b *CollateAndCharset) bool { - if a == b { - return true + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - if a == nil || b == nil { - return false + for i, el := range node { + if errF := a.rewriteValTuple(node, el, func(newNode, parent SQLNode) { + parent.(Values)[i] = newNode.(ValTuple) + }); errF != nil { + return errF + } } - return a.IsDefault == b.IsDefault && - a.Value == b.Value && - a.Type == b.Type + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } - -// CloneRefOfCollateAndCharset creates a deep clone of the input. -func CloneRefOfCollateAndCharset(n *CollateAndCharset) *CollateAndCharset { - if n == nil { - return nil +func (a *application) rewriteVindexParam(parent SQLNode, node VindexParam, replacer replacerFunc) error { + var err error + if a.pre != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.pre(&a.cur) { + return nil + } } - out := *n - return &out + if errF := a.rewriteColIdent(node, node.Key, func(newNode, parent SQLNode) { + err = vterrors.New(vtrpc.Code_INTERNAL, "[BUG] tried to replace 'Key' on 'VindexParam'") + }); errF != nil { + return errF + } + if err != nil { + return err + } + if a.post != nil { + a.cur.replacer = replacer + a.cur.parent = parent + a.cur.node = node + if !a.post(&a.cur) { + return errAbort + } + } + return nil } diff --git a/go/vt/sqlparser/rewriter.go b/go/vt/sqlparser/rewriter.go deleted file mode 100644 index ccffe468745..00000000000 --- a/go/vt/sqlparser/rewriter.go +++ /dev/null @@ -1,931 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -// Code generated by ASTHelperGen. DO NOT EDIT. - -package sqlparser - -func (a *application) apply(parent, node SQLNode, replacer replacerFunc) { - if node == nil || isNilValue(node) { - return - } - saved := a.cursor - a.cursor.replacer = replacer - a.cursor.node = node - a.cursor.parent = parent - if a.pre != nil && !a.pre(&a.cursor) { - a.cursor = saved - return - } - switch n := node.(type) { - case *AddColumns: - for x, el := range n.Columns { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AddColumns).Columns[idx] = newNode.(*ColumnDefinition) - } - }(x)) - } - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*AddColumns).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*AddColumns).After = newNode.(*ColName) - }) - case *AddConstraintDefinition: - a.apply(node, n.ConstraintDefinition, func(newNode, parent SQLNode) { - parent.(*AddConstraintDefinition).ConstraintDefinition = newNode.(*ConstraintDefinition) - }) - case *AddIndexDefinition: - a.apply(node, n.IndexDefinition, func(newNode, parent SQLNode) { - parent.(*AddIndexDefinition).IndexDefinition = newNode.(*IndexDefinition) - }) - case *AliasedExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).Expr = newNode.(Expr) - }) - a.apply(node, n.As, func(newNode, parent SQLNode) { - parent.(*AliasedExpr).As = newNode.(ColIdent) - }) - case *AliasedTableExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Expr = newNode.(SimpleTableExpr) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Partitions = newNode.(Partitions) - }) - a.apply(node, n.As, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).As = newNode.(TableIdent) - }) - a.apply(node, n.Hints, func(newNode, parent SQLNode) { - parent.(*AliasedTableExpr).Hints = newNode.(*IndexHints) - }) - case *AlterCharset: - case *AlterColumn: - a.apply(node, n.Column, func(newNode, parent SQLNode) { - parent.(*AlterColumn).Column = newNode.(*ColName) - }) - a.apply(node, n.DefaultVal, func(newNode, parent SQLNode) { - parent.(*AlterColumn).DefaultVal = newNode.(Expr) - }) - case *AlterDatabase: - case *AlterMigration: - case *AlterTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*AlterTable).Table = newNode.(TableName) - }) - for x, el := range n.AlterOptions { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AlterTable).AlterOptions[idx] = newNode.(AlterOption) - } - }(x)) - } - a.apply(node, n.PartitionSpec, func(newNode, parent SQLNode) { - parent.(*AlterTable).PartitionSpec = newNode.(*PartitionSpec) - }) - case *AlterView: - a.apply(node, n.ViewName, func(newNode, parent SQLNode) { - parent.(*AlterView).ViewName = newNode.(TableName) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*AlterView).Columns = newNode.(Columns) - }) - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*AlterView).Select = newNode.(SelectStatement) - }) - case *AlterVschema: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*AlterVschema).Table = newNode.(TableName) - }) - a.apply(node, n.VindexSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).VindexSpec = newNode.(*VindexSpec) - }) - for x, el := range n.VindexCols { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*AlterVschema).VindexCols[idx] = newNode.(ColIdent) - } - }(x)) - } - a.apply(node, n.AutoIncSpec, func(newNode, parent SQLNode) { - parent.(*AlterVschema).AutoIncSpec = newNode.(*AutoIncSpec) - }) - case *AndExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*AndExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*AndExpr).Right = newNode.(Expr) - }) - case *AutoIncSpec: - a.apply(node, n.Column, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Column = newNode.(ColIdent) - }) - a.apply(node, n.Sequence, func(newNode, parent SQLNode) { - parent.(*AutoIncSpec).Sequence = newNode.(TableName) - }) - case *Begin: - case *BinaryExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*BinaryExpr).Right = newNode.(Expr) - }) - case *CallProc: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*CallProc).Name = newNode.(TableName) - }) - a.apply(node, n.Params, func(newNode, parent SQLNode) { - parent.(*CallProc).Params = newNode.(Exprs) - }) - case *CaseExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Expr = newNode.(Expr) - }) - for x, el := range n.Whens { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*CaseExpr).Whens[idx] = newNode.(*When) - } - }(x)) - } - a.apply(node, n.Else, func(newNode, parent SQLNode) { - parent.(*CaseExpr).Else = newNode.(Expr) - }) - case *ChangeColumn: - a.apply(node, n.OldColumn, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).OldColumn = newNode.(*ColName) - }) - a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).NewColDefinition = newNode.(*ColumnDefinition) - }) - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*ChangeColumn).After = newNode.(*ColName) - }) - case *CheckConstraintDefinition: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CheckConstraintDefinition).Expr = newNode.(Expr) - }) - case ColIdent: - case *ColName: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ColName).Name = newNode.(ColIdent) - }) - a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { - parent.(*ColName).Qualifier = newNode.(TableName) - }) - case *CollateExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*CollateExpr).Expr = newNode.(Expr) - }) - case *ColumnDefinition: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ColumnDefinition).Name = newNode.(ColIdent) - }) - case *ColumnType: - a.apply(node, n.Length, func(newNode, parent SQLNode) { - parent.(*ColumnType).Length = newNode.(*Literal) - }) - a.apply(node, n.Scale, func(newNode, parent SQLNode) { - parent.(*ColumnType).Scale = newNode.(*Literal) - }) - case Columns: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Columns)[idx] = newNode.(ColIdent) - } - }(x)) - } - case Comments: - case *Commit: - case *ComparisonExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Right = newNode.(Expr) - }) - a.apply(node, n.Escape, func(newNode, parent SQLNode) { - parent.(*ComparisonExpr).Escape = newNode.(Expr) - }) - case *ConstraintDefinition: - a.apply(node, n.Details, func(newNode, parent SQLNode) { - parent.(*ConstraintDefinition).Details = newNode.(ConstraintInfo) - }) - case *ConvertExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Expr = newNode.(Expr) - }) - a.apply(node, n.Type, func(newNode, parent SQLNode) { - parent.(*ConvertExpr).Type = newNode.(*ConvertType) - }) - case *ConvertType: - a.apply(node, n.Length, func(newNode, parent SQLNode) { - parent.(*ConvertType).Length = newNode.(*Literal) - }) - a.apply(node, n.Scale, func(newNode, parent SQLNode) { - parent.(*ConvertType).Scale = newNode.(*Literal) - }) - case *ConvertUsingExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*ConvertUsingExpr).Expr = newNode.(Expr) - }) - case *CreateDatabase: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*CreateDatabase).Comments = newNode.(Comments) - }) - case *CreateTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*CreateTable).Table = newNode.(TableName) - }) - a.apply(node, n.TableSpec, func(newNode, parent SQLNode) { - parent.(*CreateTable).TableSpec = newNode.(*TableSpec) - }) - a.apply(node, n.OptLike, func(newNode, parent SQLNode) { - parent.(*CreateTable).OptLike = newNode.(*OptLike) - }) - case *CreateView: - a.apply(node, n.ViewName, func(newNode, parent SQLNode) { - parent.(*CreateView).ViewName = newNode.(TableName) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*CreateView).Columns = newNode.(Columns) - }) - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*CreateView).Select = newNode.(SelectStatement) - }) - case *CurTimeFuncExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Fsp, func(newNode, parent SQLNode) { - parent.(*CurTimeFuncExpr).Fsp = newNode.(Expr) - }) - case *Default: - case *Delete: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Delete).Comments = newNode.(Comments) - }) - a.apply(node, n.Targets, func(newNode, parent SQLNode) { - parent.(*Delete).Targets = newNode.(TableNames) - }) - a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { - parent.(*Delete).TableExprs = newNode.(TableExprs) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*Delete).Partitions = newNode.(Partitions) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Delete).Where = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Delete).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Delete).Limit = newNode.(*Limit) - }) - case *DerivedTable: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*DerivedTable).Select = newNode.(SelectStatement) - }) - case *DropColumn: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*DropColumn).Name = newNode.(*ColName) - }) - case *DropDatabase: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*DropDatabase).Comments = newNode.(Comments) - }) - case *DropKey: - case *DropTable: - a.apply(node, n.FromTables, func(newNode, parent SQLNode) { - parent.(*DropTable).FromTables = newNode.(TableNames) - }) - case *DropView: - a.apply(node, n.FromTables, func(newNode, parent SQLNode) { - parent.(*DropView).FromTables = newNode.(TableNames) - }) - case *ExistsExpr: - a.apply(node, n.Subquery, func(newNode, parent SQLNode) { - parent.(*ExistsExpr).Subquery = newNode.(*Subquery) - }) - case *ExplainStmt: - a.apply(node, n.Statement, func(newNode, parent SQLNode) { - parent.(*ExplainStmt).Statement = newNode.(Statement) - }) - case *ExplainTab: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*ExplainTab).Table = newNode.(TableName) - }) - case Exprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Exprs)[idx] = newNode.(Expr) - } - }(x)) - } - case *Flush: - a.apply(node, n.TableNames, func(newNode, parent SQLNode) { - parent.(*Flush).TableNames = newNode.(TableNames) - }) - case *Force: - case *ForeignKeyDefinition: - a.apply(node, n.Source, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).Source = newNode.(Columns) - }) - a.apply(node, n.ReferencedTable, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedTable = newNode.(TableName) - }) - a.apply(node, n.ReferencedColumns, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).ReferencedColumns = newNode.(Columns) - }) - a.apply(node, n.OnDelete, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnDelete = newNode.(ReferenceAction) - }) - a.apply(node, n.OnUpdate, func(newNode, parent SQLNode) { - parent.(*ForeignKeyDefinition).OnUpdate = newNode.(ReferenceAction) - }) - case *FuncExpr: - a.apply(node, n.Qualifier, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Qualifier = newNode.(TableIdent) - }) - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*FuncExpr).Exprs = newNode.(SelectExprs) - }) - case GroupBy: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(GroupBy)[idx] = newNode.(Expr) - } - }(x)) - } - case *GroupConcatExpr: - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Exprs = newNode.(SelectExprs) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).Limit = newNode.(*Limit) - }) - case *IndexDefinition: - a.apply(node, n.Info, func(newNode, parent SQLNode) { - parent.(*IndexDefinition).Info = newNode.(*IndexInfo) - }) - case *IndexHints: - for x, el := range n.Indexes { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*IndexHints).Indexes[idx] = newNode.(ColIdent) - } - }(x)) - } - case *IndexInfo: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*IndexInfo).Name = newNode.(ColIdent) - }) - a.apply(node, n.ConstraintName, func(newNode, parent SQLNode) { - parent.(*IndexInfo).ConstraintName = newNode.(ColIdent) - }) - case *Insert: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Insert).Comments = newNode.(Comments) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*Insert).Table = newNode.(TableName) - }) - a.apply(node, n.Partitions, func(newNode, parent SQLNode) { - parent.(*Insert).Partitions = newNode.(Partitions) - }) - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*Insert).Columns = newNode.(Columns) - }) - a.apply(node, n.Rows, func(newNode, parent SQLNode) { - parent.(*Insert).Rows = newNode.(InsertRows) - }) - a.apply(node, n.OnDup, func(newNode, parent SQLNode) { - parent.(*Insert).OnDup = newNode.(OnDup) - }) - case *IntervalExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*IntervalExpr).Expr = newNode.(Expr) - }) - case *IsExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*IsExpr).Expr = newNode.(Expr) - }) - case JoinCondition: - a.apply(node, n.On, replacePanic("JoinCondition On")) - a.apply(node, n.Using, replacePanic("JoinCondition Using")) - case *JoinTableExpr: - a.apply(node, n.LeftExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).LeftExpr = newNode.(TableExpr) - }) - a.apply(node, n.RightExpr, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).RightExpr = newNode.(TableExpr) - }) - a.apply(node, n.Condition, func(newNode, parent SQLNode) { - parent.(*JoinTableExpr).Condition = newNode.(JoinCondition) - }) - case *KeyState: - case *Limit: - a.apply(node, n.Offset, func(newNode, parent SQLNode) { - parent.(*Limit).Offset = newNode.(Expr) - }) - a.apply(node, n.Rowcount, func(newNode, parent SQLNode) { - parent.(*Limit).Rowcount = newNode.(Expr) - }) - case ListArg: - case *Literal: - case *Load: - case *LockOption: - case *LockTables: - case *MatchExpr: - a.apply(node, n.Columns, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Columns = newNode.(SelectExprs) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*MatchExpr).Expr = newNode.(Expr) - }) - case *ModifyColumn: - a.apply(node, n.NewColDefinition, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).NewColDefinition = newNode.(*ColumnDefinition) - }) - a.apply(node, n.First, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).First = newNode.(*ColName) - }) - a.apply(node, n.After, func(newNode, parent SQLNode) { - parent.(*ModifyColumn).After = newNode.(*ColName) - }) - case *Nextval: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Nextval).Expr = newNode.(Expr) - }) - case *NotExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*NotExpr).Expr = newNode.(Expr) - }) - case *NullVal: - case OnDup: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(OnDup)[idx] = newNode.(*UpdateExpr) - } - }(x)) - } - case *OptLike: - a.apply(node, n.LikeTable, func(newNode, parent SQLNode) { - parent.(*OptLike).LikeTable = newNode.(TableName) - }) - case *OrExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*OrExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*OrExpr).Right = newNode.(Expr) - }) - case *Order: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Order).Expr = newNode.(Expr) - }) - case OrderBy: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(OrderBy)[idx] = newNode.(*Order) - } - }(x)) - } - case *OrderByOption: - a.apply(node, n.Cols, func(newNode, parent SQLNode) { - parent.(*OrderByOption).Cols = newNode.(Columns) - }) - case *OtherAdmin: - case *OtherRead: - case *ParenSelect: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*ParenSelect).Select = newNode.(SelectStatement) - }) - case *ParenTableExpr: - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*ParenTableExpr).Exprs = newNode.(TableExprs) - }) - case *PartitionDefinition: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Name = newNode.(ColIdent) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*PartitionDefinition).Limit = newNode.(Expr) - }) - case *PartitionSpec: - a.apply(node, n.Names, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Names = newNode.(Partitions) - }) - a.apply(node, n.Number, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).Number = newNode.(*Literal) - }) - a.apply(node, n.TableName, func(newNode, parent SQLNode) { - parent.(*PartitionSpec).TableName = newNode.(TableName) - }) - for x, el := range n.Definitions { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*PartitionSpec).Definitions[idx] = newNode.(*PartitionDefinition) - } - }(x)) - } - case Partitions: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Partitions)[idx] = newNode.(ColIdent) - } - }(x)) - } - case *RangeCond: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*RangeCond).Left = newNode.(Expr) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*RangeCond).From = newNode.(Expr) - }) - a.apply(node, n.To, func(newNode, parent SQLNode) { - parent.(*RangeCond).To = newNode.(Expr) - }) - case *Release: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*Release).Name = newNode.(ColIdent) - }) - case *RenameIndex: - case *RenameTable: - case *RenameTableName: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*RenameTableName).Table = newNode.(TableName) - }) - case *RevertMigration: - case *Rollback: - case *SRollback: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SRollback).Name = newNode.(ColIdent) - }) - case *Savepoint: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*Savepoint).Name = newNode.(ColIdent) - }) - case *Select: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Select).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExprs, func(newNode, parent SQLNode) { - parent.(*Select).SelectExprs = newNode.(SelectExprs) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*Select).From = newNode.(TableExprs) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Select).Where = newNode.(*Where) - }) - a.apply(node, n.GroupBy, func(newNode, parent SQLNode) { - parent.(*Select).GroupBy = newNode.(GroupBy) - }) - a.apply(node, n.Having, func(newNode, parent SQLNode) { - parent.(*Select).Having = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Select).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Select).Limit = newNode.(*Limit) - }) - a.apply(node, n.Into, func(newNode, parent SQLNode) { - parent.(*Select).Into = newNode.(*SelectInto) - }) - case SelectExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(SelectExprs)[idx] = newNode.(SelectExpr) - } - }(x)) - } - case *SelectInto: - case *Set: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Set).Comments = newNode.(Comments) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*Set).Exprs = newNode.(SetExprs) - }) - case *SetExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SetExpr).Name = newNode.(ColIdent) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*SetExpr).Expr = newNode.(Expr) - }) - case SetExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(SetExprs)[idx] = newNode.(*SetExpr) - } - }(x)) - } - case *SetTransaction: - a.apply(node, n.SQLNode, func(newNode, parent SQLNode) { - parent.(*SetTransaction).SQLNode = newNode.(SQLNode) - }) - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*SetTransaction).Comments = newNode.(Comments) - }) - for x, el := range n.Characteristics { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*SetTransaction).Characteristics[idx] = newNode.(Characteristic) - } - }(x)) - } - case *Show: - a.apply(node, n.Internal, func(newNode, parent SQLNode) { - parent.(*Show).Internal = newNode.(ShowInternal) - }) - case *ShowBasic: - a.apply(node, n.Tbl, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Tbl = newNode.(TableName) - }) - a.apply(node, n.Filter, func(newNode, parent SQLNode) { - parent.(*ShowBasic).Filter = newNode.(*ShowFilter) - }) - case *ShowCreate: - a.apply(node, n.Op, func(newNode, parent SQLNode) { - parent.(*ShowCreate).Op = newNode.(TableName) - }) - case *ShowFilter: - a.apply(node, n.Filter, func(newNode, parent SQLNode) { - parent.(*ShowFilter).Filter = newNode.(Expr) - }) - case *ShowLegacy: - a.apply(node, n.OnTable, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).OnTable = newNode.(TableName) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).Table = newNode.(TableName) - }) - a.apply(node, n.ShowCollationFilterOpt, func(newNode, parent SQLNode) { - parent.(*ShowLegacy).ShowCollationFilterOpt = newNode.(Expr) - }) - case *StarExpr: - a.apply(node, n.TableName, func(newNode, parent SQLNode) { - parent.(*StarExpr).TableName = newNode.(TableName) - }) - case *Stream: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Stream).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { - parent.(*Stream).SelectExpr = newNode.(SelectExpr) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*Stream).Table = newNode.(TableName) - }) - case *Subquery: - a.apply(node, n.Select, func(newNode, parent SQLNode) { - parent.(*Subquery).Select = newNode.(SelectStatement) - }) - case *SubstrExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).Name = newNode.(*ColName) - }) - a.apply(node, n.StrVal, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).StrVal = newNode.(*Literal) - }) - a.apply(node, n.From, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).From = newNode.(Expr) - }) - a.apply(node, n.To, func(newNode, parent SQLNode) { - parent.(*SubstrExpr).To = newNode.(Expr) - }) - case TableExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(TableExprs)[idx] = newNode.(TableExpr) - } - }(x)) - } - case TableIdent: - case TableName: - a.apply(node, n.Name, replacePanic("TableName Name")) - a.apply(node, n.Qualifier, replacePanic("TableName Qualifier")) - case TableNames: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(TableNames)[idx] = newNode.(TableName) - } - }(x)) - } - case TableOptions: - case *TableSpec: - for x, el := range n.Columns { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Columns[idx] = newNode.(*ColumnDefinition) - } - }(x)) - } - for x, el := range n.Indexes { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Indexes[idx] = newNode.(*IndexDefinition) - } - }(x)) - } - for x, el := range n.Constraints { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*TableSpec).Constraints[idx] = newNode.(*ConstraintDefinition) - } - }(x)) - } - a.apply(node, n.Options, func(newNode, parent SQLNode) { - parent.(*TableSpec).Options = newNode.(TableOptions) - }) - case *TablespaceOperation: - case *TimestampFuncExpr: - a.apply(node, n.Expr1, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr1 = newNode.(Expr) - }) - a.apply(node, n.Expr2, func(newNode, parent SQLNode) { - parent.(*TimestampFuncExpr).Expr2 = newNode.(Expr) - }) - case *TruncateTable: - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*TruncateTable).Table = newNode.(TableName) - }) - case *UnaryExpr: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*UnaryExpr).Expr = newNode.(Expr) - }) - case *Union: - a.apply(node, n.FirstStatement, func(newNode, parent SQLNode) { - parent.(*Union).FirstStatement = newNode.(SelectStatement) - }) - for x, el := range n.UnionSelects { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*Union).UnionSelects[idx] = newNode.(*UnionSelect) - } - }(x)) - } - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Union).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Union).Limit = newNode.(*Limit) - }) - case *UnionSelect: - a.apply(node, n.Statement, func(newNode, parent SQLNode) { - parent.(*UnionSelect).Statement = newNode.(SelectStatement) - }) - case *UnlockTables: - case *Update: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*Update).Comments = newNode.(Comments) - }) - a.apply(node, n.TableExprs, func(newNode, parent SQLNode) { - parent.(*Update).TableExprs = newNode.(TableExprs) - }) - a.apply(node, n.Exprs, func(newNode, parent SQLNode) { - parent.(*Update).Exprs = newNode.(UpdateExprs) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*Update).Where = newNode.(*Where) - }) - a.apply(node, n.OrderBy, func(newNode, parent SQLNode) { - parent.(*Update).OrderBy = newNode.(OrderBy) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*Update).Limit = newNode.(*Limit) - }) - case *UpdateExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Name = newNode.(*ColName) - }) - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*UpdateExpr).Expr = newNode.(Expr) - }) - case UpdateExprs: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(UpdateExprs)[idx] = newNode.(*UpdateExpr) - } - }(x)) - } - case *Use: - a.apply(node, n.DBName, func(newNode, parent SQLNode) { - parent.(*Use).DBName = newNode.(TableIdent) - }) - case *VStream: - a.apply(node, n.Comments, func(newNode, parent SQLNode) { - parent.(*VStream).Comments = newNode.(Comments) - }) - a.apply(node, n.SelectExpr, func(newNode, parent SQLNode) { - parent.(*VStream).SelectExpr = newNode.(SelectExpr) - }) - a.apply(node, n.Table, func(newNode, parent SQLNode) { - parent.(*VStream).Table = newNode.(TableName) - }) - a.apply(node, n.Where, func(newNode, parent SQLNode) { - parent.(*VStream).Where = newNode.(*Where) - }) - a.apply(node, n.Limit, func(newNode, parent SQLNode) { - parent.(*VStream).Limit = newNode.(*Limit) - }) - case ValTuple: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(ValTuple)[idx] = newNode.(Expr) - } - }(x)) - } - case *Validation: - case Values: - for x, el := range n { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(Values)[idx] = newNode.(ValTuple) - } - }(x)) - } - case *ValuesFuncExpr: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*ValuesFuncExpr).Name = newNode.(*ColName) - }) - case VindexParam: - a.apply(node, n.Key, replacePanic("VindexParam Key")) - case *VindexSpec: - a.apply(node, n.Name, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Name = newNode.(ColIdent) - }) - a.apply(node, n.Type, func(newNode, parent SQLNode) { - parent.(*VindexSpec).Type = newNode.(ColIdent) - }) - for x, el := range n.Params { - a.apply(node, el, func(idx int) func(SQLNode, SQLNode) { - return func(newNode, container SQLNode) { - container.(*VindexSpec).Params[idx] = newNode.(VindexParam) - } - }(x)) - } - case *When: - a.apply(node, n.Cond, func(newNode, parent SQLNode) { - parent.(*When).Cond = newNode.(Expr) - }) - a.apply(node, n.Val, func(newNode, parent SQLNode) { - parent.(*When).Val = newNode.(Expr) - }) - case *Where: - a.apply(node, n.Expr, func(newNode, parent SQLNode) { - parent.(*Where).Expr = newNode.(Expr) - }) - case *XorExpr: - a.apply(node, n.Left, func(newNode, parent SQLNode) { - parent.(*XorExpr).Left = newNode.(Expr) - }) - a.apply(node, n.Right, func(newNode, parent SQLNode) { - parent.(*XorExpr).Right = newNode.(Expr) - }) - } - if a.post != nil && !a.post(&a.cursor) { - panic(abort) - } - a.cursor = saved -} diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index ea25e67b1d6..cd4f3d957dc 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -17,8 +17,7 @@ limitations under the License. package sqlparser import ( - "reflect" - "runtime" + "fmt" ) // The rewriter was heavily inspired by https://github.com/golang/tools/blob/master/go/ast/astutil/rewrite.go @@ -41,34 +40,21 @@ import ( // func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { parent := &struct{ SQLNode }{node} - defer func() { - if r := recover(); r != nil { - switch r := r.(type) { - case abortT: // nothing to do - - case *runtime.TypeAssertionError: - err = r - case *valueTypeFieldCantChangeErr: - err = r - default: - panic(r) - } - } - result = parent.SQLNode - }() - - a := &application{ - pre: pre, - post: post, - cursor: Cursor{}, - } // this is the root-replacer, used when the user replaces the root of the ast replacer := func(newNode SQLNode, _ SQLNode) { parent.SQLNode = newNode } - a.apply(parent, node, replacer) + a := &application{ + pre: pre, + post: post, + } + + err = a.rewriteSQLNode(parent, node, replacer) + if err != nil && err != errAbort { + return nil, err + } return parent.SQLNode, nil } @@ -81,9 +67,7 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode, err error) { // See Rewrite for details. type ApplyFunc func(*Cursor) bool -type abortT int - -var abort = abortT(0) // singleton, to signal termination of Apply +var errAbort = fmt.Errorf("this error is to abort the rewriter, it is not an actual error") // A Cursor describes a node encountered during Apply. // Information about the node and its parent is available @@ -112,28 +96,5 @@ type replacerFunc func(newNode, parent SQLNode) // application carries all the shared data so we can pass it around cheaply. type application struct { pre, post ApplyFunc - cursor Cursor -} - -func isNilValue(i interface{}) bool { - valueOf := reflect.ValueOf(i) - kind := valueOf.Kind() - isNullable := kind == reflect.Ptr || kind == reflect.Array || kind == reflect.Slice - return isNullable && valueOf.IsNil() -} - -// this type is here so we can catch it in the Rewrite method above -type valueTypeFieldCantChangeErr struct { - msg string -} - -// Error implements the error interface -func (e *valueTypeFieldCantChangeErr) Error() string { - return "Tried replacing a field of a value type. This is not supported. " + e.msg -} - -func replacePanic(msg string) func(newNode, parent SQLNode) { - return func(newNode, parent SQLNode) { - panic(&valueTypeFieldCantChangeErr{msg: msg}) - } + cur Cursor } diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 6131c6c5588..4a037aeead5 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -39,19 +39,6 @@ func BenchmarkVisitLargeExpression(b *testing.B) { } } -func TestBadTypeReturnsErrorAndNotPanic(t *testing.T) { - parse, err := Parse("select 42 from dual") - require.NoError(t, err) - _, err = Rewrite(parse, func(cursor *Cursor) bool { - _, ok := cursor.Node().(*Literal) - if ok { - cursor.Replace(&AliasedTableExpr{}) // this is not a valid replacement because of types - } - return true - }, nil) - require.Error(t, err) -} - func TestChangeValueTypeGivesError(t *testing.T) { parse, err := Parse("select * from a join b on a.id = b.id") require.NoError(t, err) diff --git a/go/vt/sqlparser/walker_test.go b/go/vt/sqlparser/walker_test.go index c30741029be..ec7727a0832 100644 --- a/go/vt/sqlparser/walker_test.go +++ b/go/vt/sqlparser/walker_test.go @@ -38,3 +38,24 @@ func BenchmarkWalkLargeExpression(b *testing.B) { }) } } + +func BenchmarkRewriteLargeExpression(b *testing.B) { + for i := 1; i < 7; i++ { + b.Run(fmt.Sprintf("%d", i), func(b *testing.B) { + exp := newGenerator(int64(i*100), i).expression() + count := 0 + for i := 0; i < b.N; i++ { + _, err := Rewrite(exp, func(_ *Cursor) bool { + count++ + return true + }, func(_ *Cursor) bool { + count-- + return true + }) + if err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/misc/git/hooks/visitorgen b/misc/git/hooks/asthelpers similarity index 100% rename from misc/git/hooks/visitorgen rename to misc/git/hooks/asthelpers