Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2674,7 +2674,7 @@ func (ct *ColumnType) merge(other ColumnType) error {

if other.KeyOpt != colKeyNone {
keyOptions := []ColumnKeyOption{ct.KeyOpt, other.KeyOpt}
sort.Slice(keyOptions, func(i, j int) bool {return keyOptions[i] < keyOptions[j]})
sort.Slice(keyOptions, func(i, j int) bool { return keyOptions[i] < keyOptions[j] })
if other.KeyOpt == ct.KeyOpt {
// MySQL will deduplicate key options when they are repeated.
} else if keyOptions[0] == colKeyPrimary && (keyOptions[1] == colKeyUnique || keyOptions[1] == colKeyUniqueKey) {
Expand Down Expand Up @@ -4789,7 +4789,7 @@ func (*ConvertExpr) iExpr() {}
func (*SubstrExpr) iExpr() {}
func (*TrimExpr) iExpr() {}
func (*ConvertUsingExpr) iExpr() {}
func (*CharExpr) iExpr() {}
func (*CharExpr) iExpr() {}
func (*MatchExpr) iExpr() {}
func (*GroupConcatExpr) iExpr() {}
func (*Default) iExpr() {}
Expand Down Expand Up @@ -6895,6 +6895,15 @@ func (node *TableFuncExpr) UnmarshalJSON(b []byte) error {
return nil
}

func (node *TableFuncExpr) walkSubtree(visit Visit) error {
if node == nil {
return nil
}
return Walk(
visit,
node.Exprs)
}

// TableIdent is a case sensitive SQL identifier. It will be escaped with
// backquotes if necessary.
type TableIdent struct {
Expand Down
49 changes: 33 additions & 16 deletions go/vt/sqlparser/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,39 @@ func TestDropIndex(t *testing.T) {
testIndex(t, tests)
}

// TestShowTablePrepared tests that Vitess can correctly walk all the SQLVal instances
// in a parsed SHOW TABLES statement to identify the bound variables.
func TestShowTablePrepared(t *testing.T) {
statement, err := Parse("SHOW TABLES FROM `mydb` WHERE `Tables_in_mydb` = ?")
require.NoError(t, err)
paramsCount := uint16(0)
_ = Walk(func(node SQLNode) (bool, error) {
switch node := node.(type) {
case *SQLVal:
if strings.HasPrefix(string(node.Val), ":v") {
paramsCount++
}
}
return true, nil
}, statement)
assert.Equal(t, uint16(1), paramsCount)
// TestWalkPrepared tests that Vitess can correctly walk all the SQLVal instances
// in a parsed statement to identify the bound variables.
func TestWalkPrepared(t *testing.T) {
tests := []struct {
q string
paramCnt int
}{
{
q: "SHOW TABLES FROM `mydb` WHERE `Tables_in_mydb` = ?",
paramCnt: 1,
},
{
q: "select * from table_function(?,?)",
paramCnt: 2,
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("prepare param count: %s", tt.q), func(t *testing.T) {
statement, err := Parse(tt.q)
require.NoError(t, err)
paramsCount := uint16(0)
_ = Walk(func(node SQLNode) (bool, error) {
switch node := node.(type) {
case *SQLVal:
if strings.HasPrefix(string(node.Val), ":v") {
paramsCount++
}
}
return true, nil
}, statement)
assert.Equal(t, uint16(tt.paramCnt), paramsCount)
})
}
}

func TestShowIndex(t *testing.T) {
Expand Down